use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{Notify, RwLock};
use tracing::{debug, error, info, warn};
use crate::cert::CertificateAuthority;
use crate::config::TransparentProxyConfig;
use crate::events::Event;
use crate::events::connect::Connect;
use crate::plugins::registry::PluginRegistry;
use crate::proxy::tenant_resolver::TenantResolver;
use crate::proxy::{UpstreamClient, is_closed, parse_authority_host_port, run_tls_mitm};
use crate::tenant::TenantContext;
use super::netfilter::NetfilterManager;
/// Proxy for connections redirected to it at the TCP level.
///
/// Each accepted connection is peeked: TLS connections are routed by the SNI
/// in their ClientHello (intercepted via the MITM path when a plugin wants the
/// host, otherwise forwarded verbatim), and everything else is treated as plain
/// HTTP and forwarded to the host named in its `Host` header.
pub struct TransparentProxy {
// Address the listener is bound to; `None` until `start()` has bound it.
listen_addr: Option<SocketAddr>,
// Certificate authority handed to the TLS MITM path.
ca: Arc<CertificateAuthority>,
// Plugins consulted to decide whether a TLS host is intercepted; `None` means
// nothing is intercepted.
plugin_registry: Option<Arc<RwLock<PluginRegistry>>>,
// Resolves a `TenantContext` from each peer address.
tenant_resolver: Arc<dyn TenantResolver>,
// Upstream client handed to the TLS MITM path.
upstream: UpstreamClient,
// Bind address, iptables automation flag and interface name.
config: TransparentProxyConfig,
// Signals the background accept loop to stop.
shutdown_notify: Arc<Notify>,
// iptables manager created when `config.auto_iptables` is set. NOTE(review):
// presumably kept here so the rules can be torn down with the proxy — confirm
// what `NetfilterManager` does on drop.
netfilter: Option<NetfilterManager>,
}
impl TransparentProxy {
    /// Creates a proxy that is not yet listening.
    ///
    /// Nothing is bound and no iptables rules are installed until
    /// [`TransparentProxy::start`] is called.
    pub fn new(
        ca: Arc<CertificateAuthority>,
        plugin_registry: Option<Arc<RwLock<PluginRegistry>>>,
        tenant_resolver: Arc<dyn TenantResolver>,
        upstream: UpstreamClient,
        config: TransparentProxyConfig,
        shutdown_notify: Arc<Notify>,
    ) -> Self {
        Self {
            listen_addr: None,
            ca,
            plugin_registry,
            tenant_resolver,
            upstream,
            config,
            shutdown_notify,
            netfilter: None,
        }
    }

    /// Returns the address the listener is actually bound to.
    ///
    /// `None` until [`TransparentProxy::start`] has bound the listener. The value
    /// comes from the bound socket, so a configured port of `0` is reported as
    /// the port the OS picked.
    pub fn listen_addr(&self) -> Option<SocketAddr> {
        self.listen_addr
    }

    /// Binds the listener, optionally installs iptables rules, and spawns the accept loop.
    ///
    /// The bind address is `config.listen_addr`, defaulting to `0.0.0.0:8080`.
    /// When `config.auto_iptables` is set, a [`NetfilterManager`] is created for
    /// `config.interface` (default `tailscale0`) and the bound port; a failure to
    /// set up the rules is logged as a warning, not returned.
    ///
    /// The accept loop runs in a background task until `shutdown_notify` is
    /// signalled; every accepted connection is handled in its own task.
    ///
    /// # Errors
    ///
    /// Returns an error if the bind address cannot be parsed, or if binding the
    /// listener or reading back its local address fails.
    pub async fn start(&mut self) -> anyhow::Result<()> {
        let bind_addr: SocketAddr = self
            .config
            .listen_addr
            .as_deref()
            .unwrap_or("0.0.0.0:8080")
            .parse()
            .map_err(|e| anyhow::anyhow!("Invalid transparent proxy bind address: {}", e))?;
        let listener = TcpListener::bind(bind_addr).await?;
        let local_addr = listener.local_addr()?;
        self.listen_addr = Some(local_addr);
        info!("Transparent proxy listening on {}", local_addr);

        if self.config.auto_iptables {
            let interface = self
                .config
                .interface
                .clone()
                .unwrap_or_else(|| "tailscale0".to_string());
            let mut nf = NetfilterManager::new(interface, local_addr.port());
            if let Err(e) = nf.setup() {
                warn!("Failed to set up iptables rules: {}", e);
            }
            self.netfilter = Some(nf);
        }

        let shutdown = self.shutdown_notify.clone();
        let ca = self.ca.clone();
        let plugin_registry = self.plugin_registry.clone();
        let tenant_resolver = self.tenant_resolver.clone();
        let upstream = self.upstream.clone();
        tokio::spawn(async move {
            // Register for the shutdown signal once, before the loop. Calling
            // `notified()` inside `select!` creates a fresh future each iteration
            // and drops it whenever `accept()` wins, so a `notify_waiters()` that
            // lands while a connection is being dispatched would be missed and
            // the loop would never exit.
            let mut shutdown_signal = std::pin::pin!(shutdown.notified());
            loop {
                tokio::select! {
                    _ = &mut shutdown_signal => break,
                    accept_result = listener.accept() => {
                        match accept_result {
                            Ok((stream, peer)) => {
                                info!("Transparent: accepted connection from {}", peer);
                                let ca = ca.clone();
                                let plugin_registry = plugin_registry.clone();
                                let tenant_resolver = tenant_resolver.clone();
                                let upstream = upstream.clone();
                                tokio::spawn(async move {
                                    let tenant_ctx = tenant_resolver.resolve(&peer).await;
                                    if let Err(e) = handle_transparent_connection(
                                        stream,
                                        peer,
                                        ca,
                                        plugin_registry,
                                        upstream,
                                        tenant_ctx,
                                    )
                                    .await
                                        && !is_closed(&e)
                                    {
                                        debug!("Transparent connection error from {}: {}", peer, e);
                                    }
                                });
                            }
                            // NOTE(review): a persistent accept error (e.g. fd
                            // exhaustion) is retried immediately; consider a
                            // short backoff here.
                            Err(e) => error!("Transparent accept error: {}", e),
                        }
                    }
                }
            }
        });
        Ok(())
    }
}
/// Extracts the SNI host name from a (possibly truncated) TLS ClientHello.
///
/// `buf` must start at a TLS record header, e.g. bytes peeked from a freshly
/// accepted socket. Only the first record is examined. Both the record and the
/// handshake message are clamped to their declared lengths and to the bytes
/// actually present, so trailing data from a following record is never parsed
/// and a short read simply yields `None`.
///
/// Returns `None` if:
/// - the buffer is not a handshake record carrying a ClientHello;
/// - there is no complete `server_name` extension;
/// - the first `host_name` entry is not a plausible DNS name. A plausible name
///   is non-empty and contains only ASCII letters, digits, `-`, `.` and `_`.
///
/// The name check matters because the value is client-controlled, and callers
/// build `"<name>:443"` connect strings from it.
pub fn extract_sni_from_client_hello(buf: &[u8]) -> Option<String> {
    // TLS record content type `handshake`.
    const CONTENT_TYPE_HANDSHAKE: u8 = 22;
    // Handshake message type `client_hello`.
    const HANDSHAKE_CLIENT_HELLO: u8 = 1;
    // Extension type `server_name` (RFC 6066).
    const EXT_SERVER_NAME: usize = 0;
    // `NameType` value `host_name` inside the server_name extension.
    const NAME_TYPE_HOST_NAME: u8 = 0;

    // Reads a big-endian u16 at `at`, or `None` if it runs past the slice.
    fn be_u16(bytes: &[u8], at: usize) -> Option<usize> {
        let hi = *bytes.get(at)?;
        let lo = *bytes.get(at + 1)?;
        Some((usize::from(hi) << 8) | usize::from(lo))
    }

    // Only plain DNS labels are accepted; anything else (':' , '/', spaces,
    // non-ASCII) cannot be a valid SNI host name per RFC 6066.
    fn is_plausible_host_name(name: &str) -> bool {
        !name.is_empty()
            && name
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_'))
    }

    // Returns the first `host_name` entry of a ServerNameList.
    fn first_host_name(list: &[u8]) -> Option<String> {
        // Skip the 2-byte list length and walk the entries we actually have.
        if list.len() < 2 {
            return None;
        }
        let mut at = 2;
        while at + 3 <= list.len() {
            let name_type = list[at];
            let name_len = be_u16(list, at + 1)?;
            at += 3;
            let name_end = at + name_len;
            if name_type == NAME_TYPE_HOST_NAME && name_end <= list.len() {
                let name = std::str::from_utf8(&list[at..name_end]).ok()?;
                return is_plausible_host_name(name).then(|| name.to_owned());
            }
            at = name_end;
        }
        None
    }

    // Record header: content_type(1) legacy_version(2) length(2).
    if buf.len() < 5 || buf[0] != CONTENT_TYPE_HANDSHAKE {
        return None;
    }
    let record_len = be_u16(buf, 3)?;
    let record = &buf[5..];
    let handshake = &record[..record_len.min(record.len())];

    // Handshake header: msg_type(1) length(3).
    if handshake.len() < 4 || handshake[0] != HANDSHAKE_CLIENT_HELLO {
        return None;
    }
    let body_len = (usize::from(handshake[1]) << 16) | be_u16(handshake, 2)?;
    let body = &handshake[4..];
    let ch = &body[..body_len.min(body.len())];

    // legacy_version(2) + random(32), then session_id<0..32>.
    let mut pos = 34;
    pos += 1 + usize::from(*ch.get(pos)?);
    // cipher_suites<2..2^16-2>
    pos += 2 + be_u16(ch, pos)?;
    // compression_methods<1..2^8-1>
    pos += 1 + usize::from(*ch.get(pos)?);
    // extensions<0..2^16-1>, clamped to the bytes actually received.
    let ext_len = be_u16(ch, pos)?;
    pos += 2;
    let ext_end = pos + ext_len.min(ch.len() - pos);

    while pos + 4 <= ext_end {
        let ext_type = be_u16(ch, pos)?;
        let ext_data_len = be_u16(ch, pos + 2)?;
        pos += 4;
        let ext_data_end = pos + ext_data_len;
        if ext_type == EXT_SERVER_NAME {
            if ext_data_end > ext_end {
                return None;
            }
            return first_host_name(&ch[pos..ext_data_end]);
        }
        pos = ext_data_end;
    }
    None
}
/// Decides whether a TLS connection for `hostname` should be intercepted.
///
/// With no plugin registry configured, nothing is intercepted. Otherwise the
/// name is split into host and port by `parse_authority_host_port` (443 is
/// passed as the default port). If that parse fails, the raw name is used with
/// port 443. The result is wrapped in a synthetic [`Connect`] event, and the
/// registry is asked whether any plugin can handle it.
async fn should_intercept(
    plugin_registry: &Option<Arc<RwLock<PluginRegistry>>>,
    hostname: &str,
) -> bool {
    match plugin_registry {
        None => false,
        Some(registry) => {
            let (host, port) = parse_authority_host_port(hostname, 443)
                .unwrap_or_else(|_| (hostname.to_string(), 443));
            let event: Box<dyn Event> = Box::new(Connect::new(host, port));
            let guard = registry.read().await;
            guard.can_handle(&*event)
        }
    }
}
/// Handles one redirected connection: peeks at its first bytes and routes it.
///
/// * TLS (first byte is the handshake record type `22`): the SNI host name is
///   read from the peeked ClientHello.
///   - If a plugin wants the host, the connection goes through [`run_tls_mitm`].
///   - Otherwise bytes are copied verbatim to `<sni>:443`.
///   - Connections without a usable SNI are dropped rather than forwarded to a
///     placeholder host.
/// * Anything else is treated as plain HTTP: the `Host` header is read from the
///   peeked request head and bytes are copied verbatim to that host. The port
///   is 80 unless the header names one.
///
/// Peeking leaves the bytes in the socket, so a verbatim forward delivers the
/// complete original stream upstream. `_tenant_ctx` is currently unused.
///
/// NOTE(review): the destination comes from client-supplied SNI / `Host` rather
/// than the redirected socket's original destination address, so a client can
/// steer the proxy to any host the proxy can reach — confirm this is acceptable.
///
/// # Errors
///
/// Returns socket errors from peeking or from connecting upstream. Errors from
/// the MITM path or the byte copy are logged (unless they are plain closes) and
/// not returned.
async fn handle_transparent_connection(
    mut stream: TcpStream,
    peer: SocketAddr,
    ca: Arc<CertificateAuthority>,
    plugin_registry: Option<Arc<RwLock<PluginRegistry>>>,
    upstream: UpstreamClient,
    _tenant_ctx: TenantContext,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Large enough for a ClientHello or a typical HTTP request head.
    const PEEK_BUF_LEN: usize = 8192;
    // TLS record content type `handshake`.
    const TLS_RECORD_HANDSHAKE: u8 = 22;

    // `peek` only returns what is already buffered; a ClientHello or request head
    // split across TCP segments may be seen only partially.
    let mut peek_buf = vec![0u8; PEEK_BUF_LEN];
    let n = stream.peek(&mut peek_buf).await?;
    if n == 0 {
        return Ok(());
    }
    let peeked = &peek_buf[..n];

    if peeked[0] == TLS_RECORD_HANDSHAKE {
        let Some(hostname) = extract_sni_from_client_hello(peeked) else {
            warn!(
                "Could not extract SNI from ClientHello from {}; dropping connection",
                peer
            );
            return Ok(());
        };
        info!("Transparent TLS: SNI={} from {}", hostname, peer);
        if should_intercept(&plugin_registry, &hostname).await {
            info!("Transparent: intercepting {} (plugins matched)", hostname);
            let authority = format!("{}:443", hostname);
            if let Err(e) = run_tls_mitm(upstream, stream, authority, ca, plugin_registry).await
                && !is_closed(&e)
            {
                debug!("Transparent MITM error for {}: {}", hostname, e);
            }
        } else {
            info!(
                "Transparent: forwarding {} directly (no plugins matched)",
                hostname
            );
            let mut upstream_stream = TcpStream::connect(format!("{}:443", hostname)).await?;
            match tokio::io::copy_bidirectional(&mut stream, &mut upstream_stream).await {
                Ok(_) => {}
                Err(e) if is_closed(&e) => {}
                Err(e) => debug!("Transparent forward error for {}: {}", hostname, e),
            }
        }
    } else {
        info!("Transparent HTTP: forwarding from {}", peer);
        let Some(host) = http_host_header(peeked) else {
            debug!("Transparent HTTP: no Host header found, dropping");
            return Ok(());
        };
        let mut upstream_stream =
            TcpStream::connect(authority_with_default_port(&host, 80)).await?;
        match tokio::io::copy_bidirectional(&mut stream, &mut upstream_stream).await {
            Ok(_) => {}
            Err(e) if is_closed(&e) => {}
            Err(e) => debug!("Transparent HTTP forward error for {}: {}", host, e),
        }
    }
    Ok(())
}

/// Returns the trimmed value of the first `Host` header in a (possibly partial) HTTP request.
///
/// Only the header section is inspected: scanning stops at the blank line that
/// ends the headers, so body bytes are never mistaken for headers. Invalid
/// UTF-8 is decoded lossily rather than discarding the whole request. An empty
/// value yields `None`.
fn http_host_header(request: &[u8]) -> Option<String> {
    let head_end = request
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .unwrap_or(request.len());
    let head = String::from_utf8_lossy(&request[..head_end]);
    // Skip the request line; header names are case-insensitive.
    for line in head.lines().skip(1) {
        if line.is_empty() {
            break;
        }
        if let Some((name, value)) = line.split_once(':')
            && name.eq_ignore_ascii_case("host")
        {
            let host = value.trim();
            return (!host.is_empty()).then(|| host.to_string());
        }
    }
    None
}

/// Formats a `Host` header value as a connectable `host:port` string.
///
/// The value's own port is kept when present; otherwise `default_port` is
/// appended. Bracketed IPv6 literals (`[::1]`, `[::1]:8080`) are handled. An
/// empty port (`example.com:`) is treated as absent, since the URI grammar
/// allows it.
fn authority_with_default_port(authority: &str, default_port: u16) -> String {
    let (host, port) = match authority.rfind(':') {
        // Inside a bracketed IPv6 literal, only a colon after `]` starts a port.
        Some(idx) if !authority.starts_with('[') || authority[..idx].ends_with(']') => {
            (&authority[..idx], &authority[idx + 1..])
        }
        _ => (authority, ""),
    };
    if port.is_empty() {
        format!("{}:{}", host, default_port)
    } else {
        format!("{}:{}", host, port)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extension type `server_name` (RFC 6066).
    const EXT_SERVER_NAME: u16 = 0;
    /// Extension type `supported_groups`.
    const EXT_SUPPORTED_GROUPS: u16 = 0x000a;

    #[test]
    fn test_extract_sni_from_real_client_hello() {
        let hello = build_test_client_hello("example.com");
        let sni = extract_sni_from_client_hello(&hello);
        assert_eq!(sni.as_deref(), Some("example.com"));
    }

    #[test]
    fn test_extract_sni_no_sni_extension() {
        let hello = build_test_client_hello_no_sni();
        let sni = extract_sni_from_client_hello(&hello);
        assert!(sni.is_none());
    }

    #[test]
    fn test_extract_sni_not_tls() {
        let buf = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
        let sni = extract_sni_from_client_hello(buf);
        assert!(sni.is_none());
    }

    #[test]
    fn test_extract_sni_empty() {
        let sni = extract_sni_from_client_hello(&[]);
        assert!(sni.is_none());
    }

    #[test]
    fn test_extract_sni_after_other_extension() {
        let extensions = [
            // supported_groups: [x25519]
            (EXT_SUPPORTED_GROUPS, vec![0x00, 0x02, 0x00, 0x1d]),
            (EXT_SERVER_NAME, server_name_extension("example.org")),
        ];
        let hello = build_client_hello(Some(&extensions));
        let sni = extract_sni_from_client_hello(&hello);
        assert_eq!(sni.as_deref(), Some("example.org"));
    }

    #[test]
    fn test_extract_sni_truncated_hello() {
        // Simulates a peek that returned only part of the ClientHello.
        let hello = build_test_client_hello("example.com");
        let sni = extract_sni_from_client_hello(&hello[..hello.len() - 3]);
        assert!(sni.is_none());
    }

    #[test]
    fn test_extract_sni_non_handshake_record() {
        let mut hello = build_test_client_hello("example.com");
        hello[0] = 23; // application_data
        assert!(extract_sni_from_client_hello(&hello).is_none());
    }

    /// Encodes a length as a big-endian u16; test inputs always fit.
    fn be16(value: usize) -> [u8; 2] {
        u16::try_from(value)
            .expect("test lengths fit in u16")
            .to_be_bytes()
    }

    /// Builds the data of a `server_name` extension with a single host_name entry.
    fn server_name_extension(hostname: &str) -> Vec<u8> {
        let name = hostname.as_bytes();
        let mut entry = vec![0u8]; // name_type = host_name
        entry.extend_from_slice(&be16(name.len()));
        entry.extend_from_slice(name);
        let mut data = be16(entry.len()).to_vec();
        data.extend_from_slice(&entry);
        data
    }

    /// Builds a single-record TLS 1.2 ClientHello.
    ///
    /// `None` omits the extensions block entirely; `Some` writes an extensions
    /// block containing the given `(type, data)` pairs in order.
    fn build_client_hello(extensions: Option<&[(u16, Vec<u8>)]>) -> Vec<u8> {
        let mut body = vec![3u8, 3]; // client_version TLS 1.2
        body.extend_from_slice(&[0u8; 32]); // random
        body.push(0); // session_id length
        body.extend_from_slice(&[0, 2, 0x00, 0xff]); // one cipher suite
        body.extend_from_slice(&[1, 0]); // compression: null
        if let Some(extensions) = extensions {
            let mut ext_bytes = Vec::new();
            for (ext_type, data) in extensions {
                ext_bytes.extend_from_slice(&ext_type.to_be_bytes());
                ext_bytes.extend_from_slice(&be16(data.len()));
                ext_bytes.extend_from_slice(data);
            }
            body.extend_from_slice(&be16(ext_bytes.len()));
            body.extend_from_slice(&ext_bytes);
        }

        // Handshake header: client_hello + 24-bit length.
        let mut handshake = vec![1u8, 0];
        handshake.extend_from_slice(&be16(body.len()));
        handshake.extend_from_slice(&body);

        // Record header: handshake, TLS 1.0 record version, length.
        let mut record = vec![22u8, 3, 1];
        record.extend_from_slice(&be16(handshake.len()));
        record.extend_from_slice(&handshake);
        record
    }

    fn build_test_client_hello(hostname: &str) -> Vec<u8> {
        build_client_hello(Some(&[(EXT_SERVER_NAME, server_name_extension(hostname))]))
    }

    fn build_test_client_hello_no_sni() -> Vec<u8> {
        build_client_hello(None)
    }
}