use std::{
collections::HashMap,
fs::File,
io::{Read, Write},
net::{IpAddr, Ipv4Addr},
os::unix::io::AsRawFd,
sync::Arc,
thread,
};
use anyhow::{Context, Result};
use rand::{rngs::OsRng, RngCore};
use smoltcp::{
iface::{Config, Interface, SocketHandle, SocketSet},
phy::{wait as phy_wait, Device, DeviceCapabilities, Medium, RxToken, TxToken},
socket::tcp::{self as tcp, Socket as TcpSocket},
time::{Duration as SmolDuration, Instant as SmolInstant},
wire::{
HardwareAddress, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, IpProtocol, Ipv4Address,
Ipv4Packet, TcpPacket, UdpPacket,
},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
runtime::Handle,
sync::mpsc,
};
use tracing::{debug, warn};
use crate::{
chain::ChainEngine,
dns::DnsMap,
proxy::{BoxStream, Target},
};
pub use crate::chain::ChainConfig;
/// In-memory smoltcp device backing the tunnel.
///
/// The poll loop stages at most one inbound IP packet in `rx` before each
/// `Interface::poll`, and afterwards drains whatever smoltcp queued in `tx`.
struct IpDevice {
    /// Outbound packets produced by smoltcp, waiting to be written out.
    tx: Vec<Vec<u8>>,
    /// The single inbound packet (if any) waiting to be handed to smoltcp.
    rx: Option<Vec<u8>>,
    /// MTU reported to smoltcp via `capabilities()`.
    mtu: usize,
}
impl IpDevice {
    /// Creates a device with no staged input, an empty output queue and the
    /// given MTU.
    fn new(mtu: usize) -> Self {
        Self {
            mtu,
            rx: None,
            tx: vec![],
        }
    }
}
/// Receive token that owns exactly one inbound packet.
struct OwnedRxToken(Vec<u8>);
impl RxToken for OwnedRxToken {
    /// Lends the owned packet bytes to smoltcp's parser.
    fn consume<R, F: FnOnce(&mut [u8]) -> R>(self, f: F) -> R {
        let OwnedRxToken(mut packet) = self;
        f(packet.as_mut_slice())
    }
}
/// Transmit token that appends each packet smoltcp emits to the device's
/// output queue.
struct OwnedTxToken<'a>(&'a mut Vec<Vec<u8>>);
impl TxToken for OwnedTxToken<'_> {
    /// Allocates a zeroed `len`-byte frame, lets smoltcp fill it, then queues it.
    fn consume<R, F: FnOnce(&mut [u8]) -> R>(self, len: usize, f: F) -> R {
        let OwnedTxToken(queue) = self;
        let mut frame = vec![0u8; len];
        let result = f(&mut frame);
        queue.push(frame);
        result
    }
}
impl Device for IpDevice {
    type RxToken<'a> = OwnedRxToken;
    type TxToken<'a> = OwnedTxToken<'a>;

    /// Hands smoltcp the staged inbound packet (consuming it) together with a
    /// token for any immediate reply; `None` when nothing is staged.
    fn receive(&mut self, _: SmolInstant) -> Option<(OwnedRxToken, OwnedTxToken<'_>)> {
        let packet = self.rx.take()?;
        Some((OwnedRxToken(packet), OwnedTxToken(&mut self.tx)))
    }

    /// Transmission never fails: frames are simply queued in `tx`.
    fn transmit(&mut self, _: SmolInstant) -> Option<OwnedTxToken<'_>> {
        let queue = &mut self.tx;
        Some(OwnedTxToken(queue))
    }

    /// Raw IP medium (no link-layer header) with the configured MTU.
    fn capabilities(&self) -> DeviceCapabilities {
        // `DeviceCapabilities` is non-exhaustive, so start from the default
        // and overwrite the two fields we care about.
        let mut caps = DeviceCapabilities::default();
        caps.max_transmission_unit = self.mtu;
        caps.medium = Medium::Ip;
        caps
    }
}
/// Configuration consumed by [`ProxyChainTunnel::new`].
#[derive(Debug, Clone)]
pub struct TunnelConfig {
    /// Proxy chain configuration; each intercepted TCP flow is connected
    /// through a `ChainEngine` built from it.
    pub chain: ChainConfig,
    /// Used to answer DNS queries sent to `dns_ip`, and to map a destination
    /// IPv4 address back to a hostname when choosing a flow's relay target.
    pub dns_map: DnsMap,
    /// Address assigned to the userspace interface; also the gateway of its
    /// default IPv4 route.
    pub tun_ip: Ipv4Addr,
    /// Prefix length of the interface's `tun_ip/prefix_len` address.
    pub prefix_len: u8,
    /// When `Some`, UDP port-53 packets addressed to this IP are answered from
    /// `dns_map` (when it produces a response) instead of being passed to the
    /// TCP/IP stack.
    pub dns_ip: Option<Ipv4Addr>,
}
/// The poll loop's end of the channel pair connecting one TCP flow to its
/// relay task.
struct FlowInfo {
    /// Bytes the relay read from the proxy stream, to be written to the
    /// client's TCP socket.
    from_relay: mpsc::Receiver<Vec<u8>>,
    /// Bytes read from the client's TCP socket, to be written to the proxy
    /// stream by the relay.
    to_relay: mpsc::Sender<Vec<u8>>,
}
/// Userspace TCP/IP tunnel: terminates TCP connections arriving as IP packets
/// on a TUN device in a smoltcp stack and relays each one through the
/// configured proxy chain.
pub struct ProxyChainTunnel {
    config: TunnelConfig,
}
impl ProxyChainTunnel {
    /// Creates a tunnel from `config`. Nothing runs until [`Self::spawn`] or
    /// [`Self::run`] is called.
    pub fn new(config: TunnelConfig) -> Self {
        Self { config }
    }

    /// Starts the packet loop on a new OS thread. Per-flow relay tasks are
    /// spawned onto `rt`.
    ///
    /// The returned handle yields the loop's result when joined.
    pub fn spawn(self, tun_file: File, rt: Handle) -> thread::JoinHandle<Result<()>> {
        thread::spawn(move || poll_loop(self.config, tun_file, rt))
    }

    /// Runs the tunnel from async code. The loop runs on a dedicated thread,
    /// and the join happens on Tokio's blocking pool so the runtime is not
    /// blocked.
    ///
    /// # Errors
    /// Returns the packet loop's own error, an error if the tunnel thread
    /// panicked (including the panic message when it is a string), or an error
    /// if the blocking join task itself failed.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime (via `Handle::current`).
    pub async fn run(self, tun_file: File) -> Result<()> {
        let join = self.spawn(tun_file, Handle::current());
        tokio::task::spawn_blocking(move || join.join())
            .await
            .context("failed to join the tunnel thread")?
            .map_err(|payload| {
                let reason = payload
                    .downcast_ref::<&str>()
                    .copied()
                    .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
                    .unwrap_or("non-string panic payload");
                anyhow::anyhow!("tunnel thread panicked: {reason}")
            })?
    }
}
/// Capacity, in messages (each one chunk of bytes), of each per-flow mpsc
/// channel between the poll loop and its relay task.
const CHANNEL_CAPACITY: usize = 64;
/// Size in bytes of each TCP socket's receive buffer and of its transmit buffer.
const SOCKET_BUF: usize = 64 * 1024;
/// Cap, in milliseconds, applied to smoltcp's suggested poll delay before the
/// loop waits on the TUN file descriptor.
const MAX_POLL_WAIT_MS: u64 = 5;
fn poll_loop(config: TunnelConfig, mut tun_file: File, rt: Handle) -> Result<()> {
let mtu = 1500usize;
let tun_raw_fd = tun_file.as_raw_fd();
let mut device = IpDevice::new(mtu);
let mut iface_cfg = Config::new(HardwareAddress::Ip);
iface_cfg.random_seed = OsRng.next_u64();
let mut iface = Interface::new(iface_cfg, &mut device, SmolInstant::now());
iface.set_any_ip(true);
iface.update_ip_addrs(|addrs| {
let cidr = IpCidr::new(IpAddress::Ipv4(config.tun_ip.into()), config.prefix_len);
let _ = addrs.push(cidr);
});
iface
.routes_mut()
.add_default_ipv4_route(Ipv4Address::from(config.tun_ip))
.expect("smoltcp route table is full");
let mut sockets = SocketSet::new(vec![]);
let engine = Arc::new(ChainEngine::new(config.chain));
let dns_map = config.dns_map;
let dns_ip = config.dns_ip;
let mut pending: HashMap<(IpEndpoint, IpEndpoint), SocketHandle> = HashMap::new();
let mut active: HashMap<SocketHandle, FlowInfo> = HashMap::new();
let mut read_buf = vec![0u8; mtu + 4];
loop {
let delay = iface
.poll_delay(SmolInstant::now(), &sockets)
.map(|d| d.min(SmolDuration::from_millis(MAX_POLL_WAIT_MS)));
phy_wait(tun_raw_fd, delay).ok();
let maybe_packet = match tun_file.read(&mut read_buf) {
Ok(n) if n > 0 => Some(read_buf[..n].to_vec()),
_ => None,
};
if let Some(pkt) = maybe_packet {
if let Some(dns_ip) = dns_ip {
if let Some(resp) = try_handle_dns(&pkt, dns_ip, &dns_map) {
let _ = tun_file.write_all(&resp);
continue;
}
}
if let Some((src, dst)) = extract_tcp_syn(&pkt) {
if let std::collections::hash_map::Entry::Vacant(e) = pending.entry((src, dst)) {
let listen_dst = e.key().1;
let mut sock = TcpSocket::new(
tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF]),
tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF]),
);
let listen_ep = IpListenEndpoint {
addr: Some(listen_dst.addr),
port: listen_dst.port,
};
if sock.listen(listen_ep).is_ok() {
let handle = sockets.add(sock);
e.insert(handle);
debug!("new TCP flow {src} → {listen_dst}");
}
}
}
device.rx = Some(pkt);
}
iface.poll(SmolInstant::now(), &mut device, &mut sockets);
for pkt in device.tx.drain(..) {
let _ = tun_file.write_all(&pkt);
}
pending.retain(|(_src, dst), &mut handle| {
let sock = sockets.get_mut::<TcpSocket>(handle);
if !sock.may_send() {
return true;
}
let target = {
let ip: IpAddr = match dst.addr {
IpAddress::Ipv4(a) => IpAddr::V4(Ipv4Addr::from(a)),
IpAddress::Ipv6(a) => IpAddr::V6(a.into()),
};
let hostname = if let IpAddr::V4(v4) = ip {
dns_map.lookup_hostname(v4)
} else {
None
};
match hostname {
Some(h) => Target::Host(h, dst.port),
None => Target::Ip(ip, dst.port),
}
};
let (to_relay_tx, to_relay_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
let (from_relay_tx, from_relay_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
let eng = engine.clone();
rt.spawn(relay_flow(eng, target, to_relay_rx, from_relay_tx));
active.insert(
handle,
FlowInfo {
to_relay: to_relay_tx,
from_relay: from_relay_rx,
},
);
false });
active.retain(|&handle, info| {
let sock = sockets.get_mut::<TcpSocket>(handle);
if sock.can_recv() {
if let Ok(data) = sock.recv(|b| (b.len(), b.to_vec())) {
if !data.is_empty() {
let _ = info.to_relay.try_send(data);
}
}
}
while let Ok(data) = info.from_relay.try_recv() {
let _ = sock.send_slice(&data);
}
if matches!(
sock.state(),
tcp::State::Closed | tcp::State::CloseWait | tcp::State::TimeWait
) && !sock.can_recv()
{
sockets.remove(handle);
return false;
}
true
});
}
}
/// Relays one TCP flow between the poll loop's channels and a stream obtained
/// from `engine.connect(target)`.
///
/// Behaviour by event:
/// - `smol_rx` is client → proxy data. When it closes (client EOF), the proxy
///   write half is shut down, but the proxy is still read so the response is
///   not lost.
/// - Proxy output is sent on `smol_tx`.
/// - The task ends on proxy EOF or a read error, on a proxy write error, or
///   when the poll loop drops its receiving end of `smol_tx`.
/// - A connection failure is logged and ends the task immediately.
///
/// Returning drops both channel ends, which the poll loop observes as the
/// relay finishing.
async fn relay_flow(
    engine: Arc<ChainEngine>,
    target: Target,
    mut smol_rx: mpsc::Receiver<Vec<u8>>,
    smol_tx: mpsc::Sender<Vec<u8>>,
) {
    let proxy: BoxStream = match engine.connect(target.clone()).await {
        Ok(s) => s,
        Err(e) => {
            warn!("relay: failed to connect to {target}: {e:#}");
            return;
        }
    };
    let (mut proxy_r, mut proxy_w) = tokio::io::split(proxy);
    let mut buf = vec![0u8; 8192];
    // Cleared once the client side has finished sending (channel closed).
    let mut client_open = true;
    loop {
        tokio::select! {
            // The poll loop abandoned this flow: nobody can receive our output.
            _ = smol_tx.closed() => break,
            data = smol_rx.recv(), if client_open => match data {
                Some(d) => {
                    if proxy_w.write_all(&d).await.is_err() {
                        break;
                    }
                }
                None => {
                    // Client half-closed: propagate EOF upstream, keep reading.
                    client_open = false;
                    let _ = proxy_w.shutdown().await;
                }
            },
            n = proxy_r.read(&mut buf) => match n {
                Ok(0) | Err(_) => break,
                Ok(n) => {
                    if smol_tx.send(buf[..n].to_vec()).await.is_err() {
                        break;
                    }
                }
            },
        }
    }
    debug!("relay flow ended for {target}");
}
/// Answers a DNS query locally when `packet` is an IPv4 UDP datagram sent to
/// `dns_ip` port 53 and `dns_map` produces a response.
///
/// The reply has the following shape:
/// - IPv4 header: 20 bytes, no options, TTL 64, addresses swapped.
/// - UDP header: source port 53, destination is the querier's port, checksum
///   0. A zero checksum means "not computed", which UDP over IPv4 permits.
/// - Payload: the DNS response.
///
/// Returns `None` in any of these cases, so the caller can handle the packet
/// normally:
/// - the packet is not such a query;
/// - `dns_map` has no answer;
/// - the answer would not fit in the 16-bit IPv4/UDP length fields.
fn try_handle_dns(packet: &[u8], dns_ip: Ipv4Addr, dns_map: &DnsMap) -> Option<Vec<u8>> {
    const IPV4_HEADER_LEN: usize = 20;
    const UDP_HEADER_LEN: usize = 8;

    let ipv4 = Ipv4Packet::new_checked(packet).ok()?;
    if ipv4.next_header() != IpProtocol::Udp || Ipv4Addr::from(ipv4.dst_addr()) != dns_ip {
        return None;
    }
    let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
    if udp.dst_port() != 53 {
        return None;
    }
    let dns_resp = dns_map.handle_dns_query(udp.payload())?;

    // Checked conversions: an oversized answer must not wrap the length
    // fields (which would make the slice copy below panic and kill the
    // poll thread).
    let udp_len = u16::try_from(UDP_HEADER_LEN + dns_resp.len()).ok()?;
    let total_len = u16::try_from(IPV4_HEADER_LEN + usize::from(udp_len)).ok()?;
    let mut pkt = vec![0u8; usize::from(total_len)];

    {
        let ip_hdr = &mut pkt[..IPV4_HEADER_LEN];
        ip_hdr[0] = 0x45; // version 4, IHL 5 (no options)
        ip_hdr[2..4].copy_from_slice(&total_len.to_be_bytes());
        ip_hdr[8] = 64; // TTL
        ip_hdr[9] = 17; // protocol: UDP
        ip_hdr[12..16].copy_from_slice(&ipv4.dst_addr().0);
        ip_hdr[16..20].copy_from_slice(&ipv4.src_addr().0);
        let csum = ip_checksum(ip_hdr);
        ip_hdr[10..12].copy_from_slice(&csum.to_be_bytes());
    }

    let udp_hdr = &mut pkt[IPV4_HEADER_LEN..IPV4_HEADER_LEN + UDP_HEADER_LEN];
    udp_hdr[0..2].copy_from_slice(&53u16.to_be_bytes());
    udp_hdr[2..4].copy_from_slice(&udp.src_port().to_be_bytes());
    udp_hdr[4..6].copy_from_slice(&udp_len.to_be_bytes());
    // Bytes 6..8 (UDP checksum) stay zero.

    pkt[IPV4_HEADER_LEN + UDP_HEADER_LEN..].copy_from_slice(&dns_resp);
    Some(pkt)
}
/// Internet checksum (RFC 1071) of `data`.
///
/// Computes the ones'-complement of the ones'-complement sum of `data` read as
/// big-endian 16-bit words. A trailing odd byte is treated as the high byte of
/// a word whose low byte is zero.
fn ip_checksum(data: &[u8]) -> u16 {
    let words = data.chunks_exact(2);
    let tail = words.remainder();
    let mut acc: u32 = words
        .map(|w| u32::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = tail {
        acc += u32::from(*last) << 8;
    }
    // Fold carries back into the low 16 bits until none remain.
    while acc > 0xFFFF {
        acc = (acc & 0xFFFF) + (acc >> 16);
    }
    !(acc as u16)
}
/// Returns the `(source, destination)` endpoints when `packet` is an IPv4 TCP
/// segment with SYN set and ACK clear, i.e. a connection-opening SYN.
///
/// Returns `None` for anything else, including packets whose IPv4 or TCP
/// header fails smoltcp's `new_checked` validation.
fn extract_tcp_syn(packet: &[u8]) -> Option<(IpEndpoint, IpEndpoint)> {
    let ip = Ipv4Packet::new_checked(packet).ok()?;
    let carries_tcp = ip.next_header() == IpProtocol::Tcp;
    if !carries_tcp {
        return None;
    }
    let seg = TcpPacket::new_checked(ip.payload()).ok()?;
    let opens_connection = seg.syn() && !seg.ack();
    opens_connection.then(|| {
        let src_ip = IpAddress::Ipv4(ip.src_addr());
        let dst_ip = IpAddress::Ipv4(ip.dst_addr());
        (
            IpEndpoint::new(src_ip, seg.src_port()),
            IpEndpoint::new(dst_ip, seg.dst_port()),
        )
    })
}