use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::policy::{EgressEvaluation, HostnameSource, NetworkPolicy, Protocol};
use crate::shared::SharedState;
use crate::tls::sni;
/// Size in bytes of the scratch buffer used for each read from the upstream server.
const SERVER_READ_BUF_SIZE: usize = 16384;
/// Byte threshold passed to `peek_for_sni`: once this much guest data has been
/// buffered without finding an SNI, the peek gives up.
const PEEK_BUF_SIZE: usize = 16384;
/// Maximum time `peek_for_sni` waits for the guest's first bytes before giving up.
const PEEK_BUDGET: Duration = Duration::from_secs(5);
/// Spawns the proxy task for one guest TCP connection on the given runtime.
///
/// The task is detached (its join handle is discarded). If the task returns
/// an error, it is logged at debug level and otherwise ignored.
///
/// * `guest_dst` – destination the guest addressed; used for policy evaluation.
/// * `connect_dst` – address the proxy actually connects to.
/// * `from_smoltcp` / `to_smoltcp` – channels carrying guest→server and
///   server→guest payload bytes respectively.
pub fn spawn_tcp_proxy(
    handle: &tokio::runtime::Handle,
    guest_dst: SocketAddr,
    connect_dst: SocketAddr,
    from_smoltcp: mpsc::Receiver<Bytes>,
    to_smoltcp: mpsc::Sender<Bytes>,
    shared: Arc<SharedState>,
    network_policy: Arc<NetworkPolicy>,
) {
    let task = async move {
        let outcome = tcp_proxy_task(
            guest_dst,
            connect_dst,
            from_smoltcp,
            to_smoltcp,
            shared,
            network_policy,
        )
        .await;
        if let Err(e) = outcome {
            tracing::debug!(dst = %connect_dst, error = %e, "TCP proxy task ended");
        }
    };
    handle.spawn(task);
}
/// Proxies one guest TCP connection to `connect_dst`, applying domain policy first.
///
/// Steps, in order:
/// 1. When the policy has domain rules, buffer the guest's first bytes with
///    `peek_for_sni` to look for a TLS SNI hostname.
/// 2. When the policy has domain rules, evaluate egress for `guest_dst` using
///    the SNI if one was found, otherwise the cache-only source. A `Deny`
///    (or an unexpected `DeferUntilHostname`) ends the task without connecting.
/// 3. Connect to `connect_dst` and replay any bytes buffered in step 1.
/// 4. Relay bytes in both directions until either side closes or an I/O
///    error occurs, waking `shared.proxy_wake` after each chunk handed to
///    the guest side.
///
/// # Errors
/// Only a failed `TcpStream::connect` is returned as `Err`. Policy denials
/// and relay/replay I/O failures are logged at debug level and yield `Ok(())`.
async fn tcp_proxy_task(
    guest_dst: SocketAddr,
    connect_dst: SocketAddr,
    mut from_smoltcp: mpsc::Receiver<Bytes>,
    to_smoltcp: mpsc::Sender<Bytes>,
    shared: Arc<SharedState>,
    network_policy: Arc<NetworkPolicy>,
) -> io::Result<()> {
    // Peeking is only needed when a hostname can influence the verdict.
    let (first_flight, hostname) = if network_policy.has_domain_rules() {
        peek_for_sni(&mut from_smoltcp, PEEK_BUF_SIZE, PEEK_BUDGET).await
    } else {
        (Vec::new(), None)
    };

    if network_policy.has_domain_rules() {
        let source = hostname
            .as_deref()
            .map(HostnameSource::Sni)
            .unwrap_or(HostnameSource::CacheOnly);
        let verdict =
            network_policy.evaluate_egress_with_source(guest_dst, Protocol::Tcp, &shared, source);
        match verdict {
            EgressEvaluation::Allow => {}
            EgressEvaluation::Deny => {
                tracing::debug!(
                    dst = %guest_dst,
                    source = source.label(),
                    "TCP egress denied by domain policy",
                );
                return Ok(());
            }
            EgressEvaluation::DeferUntilHostname => {
                debug_assert!(false, "DeferUntilHostname leaked into TCP proxy task");
                return Ok(());
            }
        }
    }

    let upstream = TcpStream::connect(connect_dst).await?;
    let (mut upstream_rx, mut upstream_tx) = upstream.into_split();

    // The peeked bytes were consumed from the channel, so they must reach
    // the server before anything else the guest sends.
    if !first_flight.is_empty() {
        if let Err(e) = upstream_tx.write_all(&first_flight).await {
            tracing::debug!(dst = %connect_dst, error = %e, "replay of buffered first flight failed");
            return Ok(());
        }
    }

    let mut scratch = vec![0u8; SERVER_READ_BUF_SIZE];
    loop {
        tokio::select! {
            outbound = from_smoltcp.recv() => {
                // `None` means the guest side of the channel is gone.
                let Some(chunk) = outbound else { break };
                if let Err(e) = upstream_tx.write_all(&chunk).await {
                    tracing::debug!(dst = %connect_dst, error = %e, "write to server failed");
                    break;
                }
            }
            inbound = upstream_rx.read(&mut scratch) => {
                let n = match inbound {
                    // A zero-length read is EOF from the server.
                    Ok(0) => break,
                    Ok(n) => n,
                    Err(e) => {
                        tracing::debug!(dst = %connect_dst, error = %e, "read from server failed");
                        break;
                    }
                };
                let payload = Bytes::copy_from_slice(&scratch[..n]);
                if to_smoltcp.send(payload).await.is_err() {
                    break;
                }
                shared.proxy_wake.wake();
            }
        }
    }
    Ok(())
}
/// Buffers the guest's first bytes from `rx` while looking for a TLS SNI hostname.
///
/// Returns `(buffered_bytes, sni)`. The buffered bytes are everything that was
/// consumed from the channel and must be replayed to the server by the caller.
/// `sni` is the extracted name with trailing dots removed and ASCII-lowercased,
/// or `None` when no name was found.
///
/// The peek stops when any of the following happens:
/// * `budget` elapses (checked first on every iteration because the select is biased);
/// * the channel closes;
/// * the first buffered byte is not `0x16`;
/// * `sni::extract_sni` returns a name;
/// * at least `max` bytes have been buffered. The check runs after a chunk is
///   appended, so the buffer can exceed `max` by part of one chunk.
///
/// Zero-length chunks carry no information and are skipped, so an empty chunk
/// arriving first is not mistaken for non-TLS traffic.
async fn peek_for_sni(
    rx: &mut mpsc::Receiver<Bytes>,
    max: usize,
    budget: Duration,
) -> (Vec<u8>, Option<String>) {
    // Size the initial allocation from the caller's cap, bounded to avoid
    // over-allocating for flows that turn out to be tiny.
    let mut buf = Vec::with_capacity(max.min(8192));
    let timeout_fut = tokio::time::sleep(budget);
    tokio::pin!(timeout_fut);
    let raw_sni = loop {
        tokio::select! {
            biased;
            _ = &mut timeout_fut => break None,
            data = rx.recv() => {
                let Some(bytes) = data else { break None };
                if bytes.is_empty() {
                    // Nothing to classify yet; keep waiting for real data.
                    continue;
                }
                buf.extend_from_slice(&bytes);
                // Only a leading 0x16 byte is treated as a possible TLS
                // handshake; anything else ends the peek immediately.
                if buf.first() != Some(&0x16) {
                    break None;
                }
                if let Some(name) = sni::extract_sni(&buf) {
                    break Some(name);
                }
                if buf.len() >= max {
                    break None;
                }
            }
        }
    };
    let canonical = raw_sni.map(|s| s.trim_end_matches('.').to_ascii_lowercase());
    (buf, canonical)
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a minimal TLS record containing a ClientHello-shaped handshake whose
/// only extension is a server-name entry for `sni`.
///
/// Layout: 5-byte record header (`0x16 0x03 0x01` + length), 4-byte handshake
/// header (`0x01` + 24-bit length), then a body of version `0x0303`, 32 zero
/// bytes, a zero-length session id, one cipher suite (`0x002f`), one
/// compression method (`0x00`) and the extensions block.
fn synthetic_client_hello(sni: &str) -> Vec<u8> {
    let host = sni.as_bytes();
    let host_len = host.len() as u16;
    // Each enclosing length adds the header bytes of the layer it wraps.
    let server_name_list_len = 3 + host_len;
    let extension_data_len = 2 + server_name_list_len;
    let extensions_total = 4 + extension_data_len;

    let body = [
        &[0x03u8, 0x03][..],
        &[0u8; 32][..],
        &[0x00u8][..],
        &[0x00u8, 0x02, 0x00, 0x2f][..],
        &[0x01u8, 0x00][..],
        &extensions_total.to_be_bytes()[..],
        &[0x00u8, 0x00][..],
        &extension_data_len.to_be_bytes()[..],
        &server_name_list_len.to_be_bytes()[..],
        &[0x00u8][..],
        &host_len.to_be_bytes()[..],
        host,
    ]
    .concat();

    // The handshake length is 24 bits: drop the top byte of the u32.
    let handshake_len = body.len() as u32;
    let handshake = [
        &[0x01u8][..],
        &handshake_len.to_be_bytes()[1..],
        &body[..],
    ]
    .concat();

    let record_len = handshake.len() as u16;
    [
        &[0x16u8, 0x03, 0x01][..],
        &record_len.to_be_bytes()[..],
        &handshake[..],
    ]
    .concat()
}
/// A complete ClientHello in a single chunk yields the lowercased SNI and the
/// full chunk back for replay.
#[tokio::test]
async fn peek_for_sni_extracts_and_canonicalizes() {
    // Mixed case in the input exercises the lowercase canonicalization.
    let client_hello = synthetic_client_hello("Example.COM");
    let (sender, mut receiver) = mpsc::channel(4);
    sender
        .send(Bytes::from(client_hello.clone()))
        .await
        .unwrap();
    drop(sender);

    let (replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;

    assert_eq!(hostname.as_deref(), Some("example.com"));
    assert_eq!(replay, client_hello);
}
/// A channel that closes before any data arrives produces an empty buffer and no SNI.
#[tokio::test]
async fn peek_for_sni_returns_none_on_channel_close_without_data() {
    let (sender, mut receiver) = mpsc::channel::<Bytes>(1);
    drop(sender);

    let outcome = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;

    let (replay, hostname) = outcome;
    assert!(replay.is_empty());
    assert_eq!(hostname, None);
}
/// Plain HTTP bytes yield no SNI, but the consumed bytes are still returned.
#[tokio::test]
async fn peek_for_sni_returns_none_on_non_tls_data() {
    let request = Bytes::from_static(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n");
    let (sender, mut receiver) = mpsc::channel(4);
    sender.send(request).await.unwrap();
    drop(sender);

    let (replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;

    assert!(
        !replay.is_empty(),
        "buffered bytes must be returned for replay"
    );
    assert_eq!(hostname, None);
}
/// With the sender alive but silent, the peek returns once the budget elapses.
#[tokio::test]
async fn peek_for_sni_falls_back_on_timeout() {
    let short_budget = Duration::from_millis(50);
    let (idle_sender, mut receiver) = mpsc::channel::<Bytes>(1);

    let (replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, short_budget).await;
    // Dropped only now so the channel stayed open for the whole peek.
    drop(idle_sender);

    assert!(replay.is_empty());
    assert_eq!(hostname, None);
}
/// The peek stops consuming from the channel once the byte cap is reached.
///
/// Three 8192-byte chunks are queued. The first starts with `0x16`, so the
/// non-TLS early exit does not trigger, and the rest is zero filler. The
/// peek must give up after reaching the cap and leave the third chunk unread.
#[tokio::test]
async fn peek_for_sni_caps_at_max_bytes() {
    let (tx, mut rx) = mpsc::channel(4);
    let mut first = vec![0u8; 8192];
    first[0] = 0x16;
    tx.send(Bytes::from(first)).await.unwrap();
    tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
    tx.send(Bytes::from(vec![0u8; 8192])).await.unwrap();
    drop(tx);
    let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    assert_eq!(sni, None, "zero-filled handshake data must not yield an SNI");
    assert!(
        buf.len() >= PEEK_BUF_SIZE,
        "buffer must hit the cap before bail-out: got {}",
        buf.len()
    );
    // A lower bound alone would also pass if the peek drained the whole
    // channel; the unread third chunk proves it actually stopped at the cap.
    let leftover = rx.recv().await;
    assert_eq!(
        leftover.map(|chunk| chunk.len()),
        Some(8192),
        "chunks past the cap must be left in the channel"
    );
}
/// A first byte other than `0x16` ends the peek without waiting for the budget.
///
/// The sender is kept alive across the peek on purpose. If it were dropped
/// first, channel close alone would make the call return quickly and the
/// timing assertion would pass even without the first-byte check.
#[tokio::test]
async fn peek_for_sni_bails_immediately_on_non_tls_first_byte() {
    let (tx, mut rx) = mpsc::channel(4);
    tx.send(Bytes::from_static(b"GET / HTTP/1.1\r\nHost: x\r\n\r\n"))
        .await
        .unwrap();
    let started = std::time::Instant::now();
    let (buf, sni) = peek_for_sni(&mut rx, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    let elapsed = started.elapsed();
    // Only now close the channel; the peek must already have returned.
    drop(tx);
    assert_eq!(sni, None);
    assert!(buf.starts_with(b"GET"));
    assert!(
        elapsed < Duration::from_millis(500),
        "non-TLS bail must be fast: took {elapsed:?}"
    );
}
use std::net::IpAddr;
use std::time::Duration as StdDuration;
use crate::policy::{Action, Destination, NetworkPolicy, PortRange, Rule};
use crate::shared::{ResolvedHostnameFamily, SharedState};
// Destination IP used by the integration tests in this module; the tests
// bind different hostnames to it in the resolved-hostname cache.
const SHARED_FASTLY_IP: &str = "151.101.0.223";
/// Creates a `SharedState` whose hostname cache holds one IPv4 entry mapping
/// `host` to `ip`, cached with a 60-second duration.
fn shared_with(host: &str, ip: &str) -> SharedState {
    let state = SharedState::new(4);
    let addr: IpAddr = ip.parse().unwrap();
    let ttl = StdDuration::from_secs(60);
    state.cache_resolved_hostname(host, ResolvedHostnameFamily::Ipv4, [addr], ttl);
    state
}
/// Builds an egress rule allowing TCP port 443 to exactly `domain`.
fn allow_https(domain: &str) -> Rule {
    let destination = Destination::Domain(domain.parse().unwrap());
    let protocols = vec![Protocol::Tcp];
    let ports = vec![PortRange::single(443)];
    Rule {
        direction: crate::policy::Direction::Egress,
        destination,
        protocols,
        ports,
        action: Action::Allow,
    }
}
/// The cache binds the destination IP to an allowed host, but the ClientHello
/// names a different host: the verdict must follow the SNI and deny.
#[tokio::test]
async fn integration_sni_overrides_cache_for_over_allow() {
    let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
    let policy = NetworkPolicy {
        default_egress: Action::Deny,
        default_ingress: Action::Allow,
        rules: vec![allow_https("pypi.org")],
    };
    let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

    let (sender, mut receiver) = mpsc::channel(4);
    let hello = Bytes::from(synthetic_client_hello("evil.com"));
    sender.send(hello).await.unwrap();
    drop(sender);

    let (replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    assert_eq!(hostname.as_deref(), Some("evil.com"));
    assert!(!replay.is_empty());

    let source = match hostname.as_deref() {
        Some(name) => HostnameSource::Sni(name),
        None => HostnameSource::CacheOnly,
    };
    let verdict = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
    assert_eq!(
        verdict,
        EgressEvaluation::Deny,
        "SNI=evil.com must not piggy-back on the cached pypi.org match",
    );
}
/// The cache binds the destination IP to a denied host, but the ClientHello
/// names a different host: the verdict must follow the SNI and allow.
#[tokio::test]
async fn integration_sni_overrides_cache_for_over_block() {
    let shared = shared_with("ads.example.com", SHARED_FASTLY_IP);
    let blocked = Destination::Domain("ads.example.com".parse().unwrap());
    let policy = NetworkPolicy {
        default_egress: Action::Allow,
        default_ingress: Action::Allow,
        rules: vec![Rule::deny_egress(blocked)],
    };
    let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

    let (sender, mut receiver) = mpsc::channel(4);
    let hello = Bytes::from(synthetic_client_hello("api.example.com"));
    sender.send(hello).await.unwrap();
    drop(sender);

    let (_replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    assert_eq!(hostname.as_deref(), Some("api.example.com"));

    let source = match hostname.as_deref() {
        Some(name) => HostnameSource::Sni(name),
        None => HostnameSource::CacheOnly,
    };
    let verdict = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
    assert_eq!(
        verdict,
        EgressEvaluation::Allow,
        "SNI=api.example.com must not be caught by the deny on ads.example.com",
    );
}
/// Non-TLS traffic carries no SNI, so evaluation uses the cache-only source
/// and the cached hostname binding allows the connection.
#[tokio::test]
async fn integration_non_tls_falls_back_to_cache() {
    let shared = shared_with("pypi.org", SHARED_FASTLY_IP);
    let policy = NetworkPolicy {
        default_egress: Action::Deny,
        default_ingress: Action::Allow,
        rules: vec![allow_https("pypi.org")],
    };
    let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

    let (sender, mut receiver) = mpsc::channel(4);
    let request = Bytes::from_static(b"GET / HTTP/1.1\r\nHost: pypi.org\r\n\r\n");
    sender.send(request).await.unwrap();
    drop(sender);

    let (replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    assert_eq!(hostname, None, "non-TLS data → no SNI");
    assert!(
        !replay.is_empty(),
        "buffered bytes must survive for replay"
    );

    let source = match hostname.as_deref() {
        Some(name) => HostnameSource::Sni(name),
        None => HostnameSource::CacheOnly,
    };
    let verdict = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
    assert_eq!(
        verdict,
        EgressEvaluation::Allow,
        "cache-only fallback must still allow the cached hostname's IP",
    );
}
/// An SNI matching an allowed domain suffix is allowed when the cache also
/// binds that hostname to the destination IP.
#[tokio::test]
async fn integration_sni_matches_domain_suffix_with_cache_binding() {
    let shared = shared_with("files.pythonhosted.org", SHARED_FASTLY_IP);
    let suffix_rule = Rule {
        direction: crate::policy::Direction::Egress,
        destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
        protocols: vec![Protocol::Tcp],
        ports: vec![PortRange::single(443)],
        action: Action::Allow,
    };
    let policy = NetworkPolicy {
        default_egress: Action::Deny,
        default_ingress: Action::Allow,
        rules: vec![suffix_rule],
    };
    let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

    let (sender, mut receiver) = mpsc::channel(4);
    let hello = Bytes::from(synthetic_client_hello("files.pythonhosted.org"));
    sender.send(hello).await.unwrap();
    drop(sender);

    let (_replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    let source = match hostname.as_deref() {
        Some(name) => HostnameSource::Sni(name),
        None => HostnameSource::CacheOnly,
    };
    let verdict = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
    assert_eq!(verdict, EgressEvaluation::Allow);
}
/// An SNI matching an allowed domain suffix is still denied when the cache
/// holds no binding for the destination IP.
#[tokio::test]
async fn integration_sni_denies_domain_suffix_without_cache_binding() {
    // Fresh state: nothing has been cached for any hostname.
    let shared = SharedState::new(4);
    let suffix_rule = Rule {
        direction: crate::policy::Direction::Egress,
        destination: Destination::DomainSuffix(".pythonhosted.org".parse().unwrap()),
        protocols: vec![Protocol::Tcp],
        ports: vec![PortRange::single(443)],
        action: Action::Allow,
    };
    let policy = NetworkPolicy {
        default_egress: Action::Deny,
        default_ingress: Action::Allow,
        rules: vec![suffix_rule],
    };
    let dst = SocketAddr::new(SHARED_FASTLY_IP.parse().unwrap(), 443);

    let (sender, mut receiver) = mpsc::channel(4);
    let hello = Bytes::from(synthetic_client_hello("files.pythonhosted.org"));
    sender.send(hello).await.unwrap();
    drop(sender);

    let (_replay, hostname) = peek_for_sni(&mut receiver, PEEK_BUF_SIZE, PEEK_BUDGET).await;
    let source = match hostname.as_deref() {
        Some(name) => HostnameSource::Sni(name),
        None => HostnameSource::CacheOnly,
    };
    let verdict = policy.evaluate_egress_with_source(dst, Protocol::Tcp, &shared, source);
    assert_eq!(verdict, EgressEvaluation::Deny);
}
}