1use std::io;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::time::Duration;
13
14use bytes::Bytes;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::mpsc;
18
19use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
20use crate::shared::SharedState;
21use crate::tls::sni;
22
/// Size in bytes of the buffer used for each read from the upstream server
/// socket in `tcp_proxy_task`.
const SERVER_READ_BUF_SIZE: usize = 16384;

/// Byte cap handed to `peek_for_sni` when buffering the guest's first flight.
const PEEK_BUF_SIZE: usize = 16384;

/// Time budget handed to `peek_for_sni` for receiving the guest's first flight.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
36
37#[allow(clippy::too_many_arguments)]
53pub fn spawn_tcp_proxy(
54 handle: &tokio::runtime::Handle,
55 guest_dst: SocketAddr,
56 connect_dst: SocketAddr,
57 from_smoltcp: mpsc::Receiver<Bytes>,
58 to_smoltcp: mpsc::Sender<Bytes>,
59 shared: Arc<SharedState>,
60 network_policy: Arc<NetworkPolicy>,
61 upstream_connected: Arc<AtomicBool>,
62) {
63 handle.spawn(async move {
64 if let Err(e) = tcp_proxy_task(
65 guest_dst,
66 connect_dst,
67 from_smoltcp,
68 to_smoltcp,
69 shared,
70 network_policy,
71 upstream_connected,
72 )
73 .await
74 {
75 tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
76 }
77 });
78}
79
80async fn tcp_proxy_task(
83 guest_dst: SocketAddr,
84 connect_dst: SocketAddr,
85 mut from_smoltcp: mpsc::Receiver<Bytes>,
86 to_smoltcp: mpsc::Sender<Bytes>,
87 shared: Arc<SharedState>,
88 network_policy: Arc<NetworkPolicy>,
89 upstream_connected: Arc<AtomicBool>,
90) -> io::Result<()> {
91 let (initial_buf, sni) = if network_policy.has_domain_rules() {
95 peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
96 } else {
97 (Vec::new(), None)
98 };
99
100 if network_policy.has_domain_rules() {
106 let source = match sni.as_deref() {
107 Some(name) => HostnameSource::Sni(name),
108 None => HostnameSource::CacheOnly,
109 };
110 match network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source)
111 {
112 EgressEvaluation::Allow => {}
113 EgressEvaluation::Deny => {
114 tracing::debug!(
115 dst = %guest_dst,
116 source = source.label(),
117 "TCP egress denied by domain policy",
118 );
119 return Ok(());
120 }
121 EgressEvaluation::DeferUntilHostname => {
122 debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
123 return Ok(());
124 }
125 }
126 }
127
128 let stream = TcpStream::connect(connect_dst).await?;
129 upstream_connected.store(true, Ordering::Release);
130 let (mut server_rx, mut server_tx) = stream.into_split();
131
132 if !initial_buf.is_empty()
134 && let Err(e) = server_tx.write_all(&initial_buf).await
135 {
136 tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
137 return Ok(());
138 }
139
140 let mut server_buf = vec![0u8; SERVER_READ_BUF_SIZE];
141
142 loop {
147 tokio::select! {
148 data = from_smoltcp.recv() => {
150 match data {
151 Some(bytes) => {
152 if let Err(e) = server_tx.write_all(&bytes).await {
153 tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
154 break;
155 }
156 }
157 None => break,
159 }
160 }
161
162 result = server_rx.read(&mut server_buf) => {
164 match result {
165 Ok(0) => break, Ok(n) => {
167 let data = Bytes::copy_from_slice(&server_buf[..n]);
168 if to_smoltcp.send(data).await.is_err() {
169 break;
171 }
172 shared.proxy_wake.wake();
175 }
176 Err(e) => {
177 tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
178 break;
179 }
180 }
181 }
182 }
183 }
184
185 Ok(())
186}
187
188async fn peek_for_sni(
198 rx: &mut mpsc::Receiver<Bytes>,
199 max: usize,
200 budget: Duration,
201) -> (Vec<u8>, Option<String>) {
202 let mut buf = Vec::with_capacity(PEEK_BUF_SIZE.min(8192));
203 let timeout_fut = tokio::time::sleep(budget);
204 tokio::pin!(timeout_fut);
205
206 let raw_sni = loop {
207 tokio::select! {
208 biased;
209 _ = &mut timeout_fut => break None,
210 data = rx.recv() => {
211 match data {
212 Some(bytes) => {
213 buf.extend_from_slice(&bytes);
214 if buf.first() != Some(&0x16) {
219 break None;
220 }
221 if let Some(name) = sni::extract_sni(&buf) {
222 break Some(name);
223 }
224 if buf.len() >= max {
225 break None;
226 }
227 }
228 None => break None,
229 }
230 }
231 }
232 };
233
234 let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
235 (buf, canonical)
236}
237
/// Tests for `peek_for_sni` on its own and combined with
/// `NetworkPolicy::evaluate_egress_with_source`.
#[cfg(test)]
mod tests {
    use std::net::IpAddr;
    use std::time::Duration as StdDuration;

    use super::*;
    use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
    use crate::shared::{ResolvedHostnameFamily, SharedState};

    /// IP address that the integration tests bind cached hostnames to.
    const SHARED_FASTLY_IP: &str = "151.101.0.223";

    /// Builds a single TLS record holding a ClientHello whose only extension
    /// is a `server_name` entry for `sni`.
    fn synthetic_client_hello(sni: &str) -> Vec<u8> {
        let host_bytes = sni.as_bytes();
        let host_len = u16::try_from(host_bytes.len()).expect("test hostname fits in u16");
        // name_type (1) + host length field (2) + host
        let server_name_list_len = 3 + host_len;
        // list length field (2) + list
        let extension_data_len = 2 + server_name_list_len;
        // extension type (2) + extension length field (2) + data
        let extensions_total = 4 + extension_data_len;

        let mut body = Vec::new();
        // legacy_version
        body.extend_from_slice(&[0x03, 0x03]);
        // random
        body.extend_from_slice(&[0u8; 32]);
        // session_id length
        body.push(0);
        // cipher_suites: length 2, one suite
        body.extend_from_slice(&[0x00, 0x02, 0x00, 0x2f]);
        // compression_methods: length 1, null
        body.extend_from_slice(&[0x01, 0x00]);
        body.extend_from_slice(&extensions_total.to_be_bytes());
        // extension type 0 = server_name
        body.extend_from_slice(&[0x00, 0x00]);
        body.extend_from_slice(&extension_data_len.to_be_bytes());
        body.extend_from_slice(&server_name_list_len.to_be_bytes());
        // name_type 0 = host_name
        body.push(0x00);
        body.extend_from_slice(&host_len.to_be_bytes());
        body.extend_from_slice(host_bytes);

        let handshake_len = u32::try_from(body.len()).expect("test handshake fits in u32");
        let mut hs = Vec::new();
        // handshake type 1 = client_hello
        hs.push(0x01);
        // 24-bit length: low three bytes of the big-endian u32
        hs.extend_from_slice(&handshake_len.to_be_bytes()[1..]);
        hs.extend_from_slice(&body);

        let record_len = u16::try_from(hs.len()).expect("test record fits in u16");
        let mut record = Vec::new();
        // content type 0x16 = handshake, record version 3.1
        record.extend_from_slice(&[0x16, 0x03, 0x01]);
        record.extend_from_slice(&record_len.to_be_bytes());
        record.extend_from_slice(&hs);

        record
    }

    /// Returns a `SharedState` whose hostname cache maps `host` to `ip` (IPv4)
    /// for 60 seconds.
    fn shared_with(host: &str, ip: &str) -> SharedState {
        let shared = SharedState::new(4);
        shared.cache_resolved_hostname(
            host,
            ResolvedHostnameFamily::Ipv4,
            [ip.parse::<IpAddr>().unwrap()],
            StdDuration::from_secs(60),
        );
        shared
    }

    /// Egress rule allowing TCP/443 to the exact `domain`.
    fn allow_https(domain: &str) -> Rule {
        Rule {
            direction: crate::policy::Direction::Egress,
            destination: Destination::Domain(domain.parse().unwrap()),
            protocols: vec![Protocol::Tcp],
            ports: vec![PortRange::single(443)],
            action: Action::Allow,
        }
    }

    /// Egress rule allowing TCP/443 to any host matching the domain `suffix`.
    fn allow_https_suffix(suffix: &str) -> Rule {
        Rule {
            direction: crate::policy::Direction::Egress,
            destination: Destination::DomainSuffix(suffix.parse().unwrap()),
            protocols: vec![Protocol::Tcp],
            ports: vec![PortRange::single(443)],
            action: Action::Allow,
        }
    }

    #[tokio::test]
    async fn peek_for_sni_extracts_and_canonicalizes() {
        let (tx, mut rx) = mpsc::channel(4);
        let hello = synthetic_client_hello("Example.COM");
        tx.send(Bytes::from(hello.clone())).await.unwrap();
        drop(tx);

        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("example.com"));
        assert_eq!(buf, hello);
    }

    #[tokio::test]
    async fn peek_for_sni_returns_none_on_channel_close_without_data() {
        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert!(buf.is_empty());
        assert_eq!(sni, None);
    }

    #[tokio::test]
    async fn peek_for_sni_returns_none_on_non_tls_data() {
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from_static(
            b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
        ))
        .await
        .unwrap();
        drop(tx);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert!(
            !buf.is_empty(),
            "buffered bytes must be returned for replay"
        );
        assert_eq!(sni, None);
    }

    #[tokio::test]
    async fn peek_for_sni_falls_back_on_timeout() {
        // The sender stays alive so only the budget can end the peek.
        let (tx, mut rx) = mpsc::channel::<Bytes>(1);
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, Duration::from_millis(50)).await;
        drop(tx);
        assert!(buf.is_empty());
        assert_eq!(sni, None);
    }

    #[tokio::test]
    async fn peek_for_sni_caps_at_max_bytes() {
        let (tx, mut rx) = mpsc::channel(4);
        // Starts with the handshake content type so the non-TLS bail does not
        // trigger; the rest is zeros, from which no SNI is expected.
        let mut first = vec![0u8; 8192];
        first[0] = 0x16;
        tx.send(Bytes::from(first)).await.unwrap();
        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
        tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
        drop(tx);

        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni, None, "no SNI expected from an all-zero record");
        assert!(
            buf.len() >= PEEK_BUF_SIZE,
            "buffer must hit the cap before bail-out: got {}",
            buf.len()
        );
        assert!(
            buf.len() < 3 * 8192,
            "peeking must stop at the cap and leave the third chunk queued: got {}",
            buf.len()
        );
    }

    #[tokio::test]
    async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
            .await
            .unwrap();

        // The sender is deliberately kept alive across the call: if it were
        // dropped first, channel close alone would make the peek return fast
        // and this test could not detect a missing first-byte bail.
        let started = std::time::Instant::now();
        let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let elapsed = started.elapsed();
        drop(tx);

        assert_eq!(sni, None);
        assert!(buf.starts_with(b"GET"));
        assert!(
            elapsed < Duration::from_millis(500),
            "non-TLS bail must be fast: took {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn integration_sni_overrides_cache_for_over_allow() {
        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https("pypi.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello("evil.com")))
            .await
            .unwrap();
        drop(tx);

        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("evil.com"));
        assert!(!initial_buf.is_empty());

        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(
            eval,
            EgressEvaluation::Deny,
            "SNI=evil.com must not piggy-back on the cached pypi.org match",
        );
    }

    #[tokio::test]
    async fn integration_sni_overrides_cache_for_over_block() {
        let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Allow,
            default_ingress: Action::Allow,
            rules: vec![Rule::deny_egress(Destination::Domain(
                "ads.example.com".parse().unwrap(),
            ))],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello("api.example.com")))
            .await
            .unwrap();
        drop(tx);

        let (_initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni.as_deref(), Some("api.example.com"));

        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(
            eval,
            EgressEvaluation::Allow,
            "SNI=api.example.com must not be caught by the deny on ads.example.com",
        );
    }

    #[tokio::test]
    async fn integration_non_tls_falls_back_to_cache() {
        let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https("pypi.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from_static(
            b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n",
        ))
        .await
        .unwrap();
        drop(tx);

        let (initial_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        assert_eq!(sni, None, "non-TLS data → no SNI");
        assert!(
            !initial_buf.is_empty(),
            "buffered bytes must survive for replay"
        );

        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(
            eval,
            EgressEvaluation::Allow,
            "cache-only fallback must still allow the cached hostname's IP",
        );
    }

    #[tokio::test]
    async fn integration_sni_matches_domain_suffix_with_cache_binding() {
        let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https_suffix(".pythonhosted.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello(
            "files.pythonhosted.org",
        )))
        .await
        .unwrap();
        drop(tx);

        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(eval, EgressEvaluation::Allow);
    }

    #[tokio::test]
    async fn integration_sni_denies_domain_suffix_without_cache_binding() {
        // No hostname is cached for the destination IP in this scenario.
        let shared = SharedState::new(4);
        let policy = NetworkPolicy {
            default_egress: Action::Deny,
            default_ingress: Action::Allow,
            rules: vec![allow_https_suffix(".pythonhosted.org")],
        };
        let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

        let (tx, mut rx) = mpsc::channel(4);
        tx.send(Bytes::from(synthetic_client_hello(
            "files.pythonhosted.org",
        )))
        .await
        .unwrap();
        drop(tx);

        let (_buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
        let source = sni
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let eval = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
        assert_eq!(eval, EgressEvaluation::Deny);
    }
}