1use std::{
21 collections::HashMap,
22 net::{IpAddr, Ipv4Addr},
23 os::unix::io::RawFd,
24 sync::Arc,
25 thread,
26};
27
28use anyhow::{Context, Result};
29use libc;
30use smoltcp::{
31 iface::{Config, Interface, SocketHandle, SocketSet},
32 phy::{wait as phy_wait, Device, DeviceCapabilities, Medium, RxToken, TxToken},
33 socket::tcp::{self as tcp, Socket as TcpSocket},
34 time::{Duration as SmolDuration, Instant as SmolInstant},
35 wire::{
36 HardwareAddress, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, IpProtocol, Ipv4Address,
37 Ipv4Packet, TcpPacket, UdpPacket,
38 },
39};
40use tokio::{
41 io::{AsyncReadExt, AsyncWriteExt},
42 runtime::Handle,
43 sync::mpsc,
44};
45use tracing::{debug, warn};
46
47use crate::{
48 chain::ChainEngine,
49 dns::DnsMap,
50 proxy::{BoxStream, Target},
51};
52
53pub use crate::chain::ChainConfig;
54
/// In-memory smoltcp device using IP (no link-layer) framing.
///
/// The poll loop stores at most one inbound packet in `rx` before each
/// `Interface::poll` and drains the packets smoltcp emitted from `tx` afterwards.
struct IpDevice {
    /// Next inbound packet to hand to smoltcp, if any.
    rx: Option<Vec<u8>>,
    /// Outbound packets produced by smoltcp, waiting to be written to the TUN fd.
    tx: Vec<Vec<u8>>,
    /// MTU advertised to smoltcp through `capabilities()`.
    mtu: usize,
}

impl IpDevice {
    /// Creates an idle device (no pending inbound packet, empty outbound queue).
    fn new(mtu: usize) -> Self {
        let (rx, tx) = (None, Vec::new());
        Self { rx, tx, mtu }
    }
}
77
78struct OwnedRxToken(Vec<u8>);
79impl RxToken for OwnedRxToken {
80 fn consume<R, F: FnOnce(&mut [u8]) -> R>(mut self, f: F) -> R {
81 f(&mut self.0)
82 }
83}
84
85struct OwnedTxToken<'a>(&'a mut Vec<Vec<u8>>);
86impl<'a> TxToken for OwnedTxToken<'a> {
87 fn consume<R, F: FnOnce(&mut [u8]) -> R>(self, len: usize, f: F) -> R {
88 let mut buf = vec![0u8; len];
89 let r = f(&mut buf);
90 self.0.push(buf);
91 r
92 }
93}
94
95impl Device for IpDevice {
96 type RxToken<'a> = OwnedRxToken;
97 type TxToken<'a> = OwnedTxToken<'a>;
98
99 fn receive(&mut self, _: SmolInstant) -> Option<(OwnedRxToken, OwnedTxToken<'_>)> {
100 self.rx
101 .take()
102 .map(|p| (OwnedRxToken(p), OwnedTxToken(&mut self.tx)))
103 }
104
105 fn transmit(&mut self, _: SmolInstant) -> Option<OwnedTxToken<'_>> {
106 Some(OwnedTxToken(&mut self.tx))
107 }
108
109 fn capabilities(&self) -> DeviceCapabilities {
110 let mut caps = DeviceCapabilities::default();
111 caps.medium = Medium::Ip;
112 caps.max_transmission_unit = self.mtu;
113 caps
114 }
115}
116
117#[derive(Debug, Clone)]
121pub struct TunnelConfig {
122 pub chain: ChainConfig,
124 pub dns_map: DnsMap,
126 pub tun_ip: Ipv4Addr,
128 pub prefix_len: u8,
130 pub dns_ip: Option<Ipv4Addr>,
133}
134
135struct FlowInfo {
138 to_relay: mpsc::Sender<Vec<u8>>,
140 from_relay: mpsc::Receiver<Vec<u8>>,
142}
143
144pub struct ProxyChainTunnel {
149 config: TunnelConfig,
150}
151
152impl ProxyChainTunnel {
153 pub fn new(config: TunnelConfig) -> Self {
155 Self { config }
156 }
157
158 pub fn spawn(self, tun_fd: RawFd, rt: Handle) -> thread::JoinHandle<Result<()>> {
163 thread::spawn(move || poll_loop(self.config, tun_fd, rt))
164 }
165
166 pub async fn run(self, tun_fd: RawFd) -> Result<()> {
176 let rt = Handle::current();
177 let config = self.config;
178 let join = thread::spawn(move || poll_loop(config, tun_fd, rt));
181 tokio::task::spawn_blocking(move || join.join())
182 .await
183 .context("tunnel thread panicked")?
184 .map_err(|_| anyhow::anyhow!("tunnel thread panicked"))
185 .and_then(|r| r)
186 }
187}
188
/// Capacity, in chunks, of each per-flow channel between the poll loop and its relay task.
const CHANNEL_CAPACITY: usize = 64;
/// Size in bytes of both the receive and the transmit buffer of every smoltcp TCP socket.
const SOCKET_BUF: usize = 1 << 16;
/// Cap, in milliseconds, on the `phy_wait` timeout derived from smoltcp's `poll_delay`.
const MAX_POLL_WAIT_MS: u64 = 5;
197
198fn poll_loop(config: TunnelConfig, tun_fd: RawFd, rt: Handle) -> Result<()> {
199 let mtu = 1500usize;
200
201 let mut device = IpDevice::new(mtu);
203
204 let mut iface_cfg = Config::new(HardwareAddress::Ip);
205 iface_cfg.random_seed = 0xdeadbeef;
206 let mut iface = Interface::new(iface_cfg, &mut device, SmolInstant::now());
207 iface.set_any_ip(true);
209 iface.update_ip_addrs(|addrs| {
210 let cidr = IpCidr::new(IpAddress::Ipv4(config.tun_ip.into()), config.prefix_len);
211 let _ = addrs.push(cidr);
212 });
213 iface
217 .routes_mut()
218 .add_default_ipv4_route(Ipv4Address::from(config.tun_ip))
219 .expect("smoltcp route table is full");
220
221 let mut sockets = SocketSet::new(vec![]);
222 let engine = Arc::new(ChainEngine::new(config.chain));
223 let dns_map = config.dns_map;
224 let dns_ip = config.dns_ip;
225
226 let mut pending: HashMap<(IpEndpoint, IpEndpoint), SocketHandle> = HashMap::new();
229
230 let mut active: HashMap<SocketHandle, FlowInfo> = HashMap::new();
232
233 let mut read_buf = vec![0u8; mtu + 4];
234
235 loop {
236 let delay = iface
238 .poll_delay(SmolInstant::now(), &sockets)
239 .map(|d| d.min(SmolDuration::from_millis(MAX_POLL_WAIT_MS)));
240 phy_wait(tun_fd, delay).ok();
241
242 let maybe_packet = {
246 let n = unsafe { libc::read(tun_fd, read_buf.as_mut_ptr().cast(), read_buf.len()) };
247 if n > 0 {
248 Some(read_buf[..n as usize].to_vec())
249 } else {
250 None
251 }
252 };
253
254 if let Some(pkt) = maybe_packet {
255 if let Some(dns_ip) = dns_ip {
257 if let Some(resp) = try_handle_dns(&pkt, dns_ip, &dns_map) {
258 unsafe {
259 libc::write(tun_fd, resp.as_ptr().cast(), resp.len());
260 }
261 continue;
263 }
264 }
265
266 if let Some((src, dst)) = extract_tcp_syn(&pkt) {
269 if let std::collections::hash_map::Entry::Vacant(e) = pending.entry((src, dst)) {
270 let listen_dst = e.key().1;
272 let mut sock = TcpSocket::new(
273 tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF]),
274 tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF]),
275 );
276 let listen_ep = IpListenEndpoint {
280 addr: Some(listen_dst.addr),
281 port: listen_dst.port,
282 };
283 if sock.listen(listen_ep).is_ok() {
284 let handle = sockets.add(sock);
285 e.insert(handle);
286 debug!("new TCP flow {src} → {listen_dst}");
287 }
288 }
289 }
290 device.rx = Some(pkt);
292 }
293
294 iface.poll(SmolInstant::now(), &mut device, &mut sockets);
296
297 for pkt in device.tx.drain(..) {
299 unsafe {
301 libc::write(tun_fd, pkt.as_ptr().cast(), pkt.len());
302 }
303 }
304
305 pending.retain(|(_src, dst), &mut handle| {
307 let sock = sockets.get_mut::<TcpSocket>(handle);
308 if !sock.may_send() {
312 return true;
313 }
314
315 let target = {
317 let ip: IpAddr = match dst.addr {
318 IpAddress::Ipv4(a) => IpAddr::V4(Ipv4Addr::from(a)),
319 IpAddress::Ipv6(a) => IpAddr::V6(a.into()),
320 };
321 let hostname = if let IpAddr::V4(v4) = ip {
322 dns_map.lookup_hostname(v4)
323 } else {
324 None
325 };
326 match hostname {
327 Some(h) => Target::Host(h, dst.port),
328 None => Target::Ip(ip, dst.port),
329 }
330 };
331
332 let (to_relay_tx, to_relay_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
334 let (from_relay_tx, from_relay_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
335
336 let eng = engine.clone();
337 rt.spawn(relay_flow(eng, target, to_relay_rx, from_relay_tx));
338
339 active.insert(
340 handle,
341 FlowInfo {
342 to_relay: to_relay_tx,
343 from_relay: from_relay_rx,
344 },
345 );
346 false });
348
349 active.retain(|&handle, info| {
351 let sock = sockets.get_mut::<TcpSocket>(handle);
352
353 if sock.can_recv() {
355 if let Ok(data) = sock.recv(|b| (b.len(), b.to_vec())) {
356 if !data.is_empty() {
357 let _ = info.to_relay.try_send(data);
358 }
359 }
360 }
361
362 while let Ok(data) = info.from_relay.try_recv() {
364 let _ = sock.send_slice(&data);
365 }
366
367 if matches!(
369 sock.state(),
370 tcp::State::Closed | tcp::State::CloseWait | tcp::State::TimeWait
371 ) && !sock.can_recv()
372 {
373 sockets.remove(handle);
374 return false;
375 }
376 true
377 });
378 }
379}
380
381async fn relay_flow(
386 engine: Arc<ChainEngine>,
387 target: Target,
388 mut smol_rx: mpsc::Receiver<Vec<u8>>,
389 smol_tx: mpsc::Sender<Vec<u8>>,
390) {
391 let proxy: BoxStream = match engine.connect(target.clone()).await {
392 Ok(s) => s,
393 Err(e) => {
394 warn!("relay: failed to connect to {target}: {e:#}");
395 return;
396 }
397 };
398
399 let (mut proxy_r, mut proxy_w) = tokio::io::split(proxy);
400 let mut buf = vec![0u8; 8192];
401
402 loop {
403 tokio::select! {
404 data = smol_rx.recv() => {
406 match data {
407 Some(d) => { if proxy_w.write_all(&d).await.is_err() { break; } }
408 None => break,
409 }
410 }
411 n = proxy_r.read(&mut buf) => {
413 match n {
414 Ok(0) | Err(_) => break,
415 Ok(n) => {
416 if smol_tx.send(buf[..n].to_vec()).await.is_err() {
417 break;
418 }
419 }
420 }
421 }
422 }
423 }
424 debug!("relay flow ended for {target}");
425}
426
427fn try_handle_dns(packet: &[u8], dns_ip: Ipv4Addr, dns_map: &DnsMap) -> Option<Vec<u8>> {
432 let ipv4 = Ipv4Packet::new_checked(packet).ok()?;
433 if ipv4.next_header() != IpProtocol::Udp {
434 return None;
435 }
436 if Ipv4Addr::from(ipv4.dst_addr()) != dns_ip {
437 return None;
438 }
439 let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
440 if udp.dst_port() != 53 {
441 return None;
442 }
443 let dns_resp = dns_map.handle_dns_query(udp.payload())?;
445
446 let udp_len = 8u16 + dns_resp.len() as u16;
448 let total_len = 20u16 + udp_len;
449 let mut pkt = vec![0u8; total_len as usize];
450
451 pkt[0] = 0x45; pkt[2] = (total_len >> 8) as u8;
454 pkt[3] = total_len as u8;
455 pkt[8] = 64; pkt[9] = 17; pkt[12..16].copy_from_slice(&ipv4.dst_addr().0); pkt[16..20].copy_from_slice(&ipv4.src_addr().0); let csum = ip_checksum(&pkt[0..20]);
460 pkt[10] = (csum >> 8) as u8;
461 pkt[11] = csum as u8;
462
463 pkt[20] = 0;
465 pkt[21] = 53; pkt[22] = (udp.src_port() >> 8) as u8;
467 pkt[23] = udp.src_port() as u8; pkt[24] = (udp_len >> 8) as u8;
469 pkt[25] = udp_len as u8;
470 pkt[28..].copy_from_slice(&dns_resp);
472
473 Some(pkt)
474}
475
/// Computes the RFC 1071 Internet checksum of `data`.
///
/// This is the one's-complement sum of big-endian 16-bit words, with an odd
/// trailing byte padded with zero. A 64-bit accumulator is used so inputs of
/// any realistic length cannot overflow.
fn ip_checksum(data: &[u8]) -> u16 {
    let mut words = data.chunks_exact(2);
    let mut sum: u64 = words
        .by_ref()
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = words.remainder() {
        sum += u64::from(*last) << 8;
    }
    // Fold carries back in (end-around carry) until the sum fits in 16 bits.
    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    // Lossless: the loop above guarantees `sum < 0x1_0000`.
    !(sum as u16)
}
491
492fn extract_tcp_syn(packet: &[u8]) -> Option<(IpEndpoint, IpEndpoint)> {
494 let ipv4 = Ipv4Packet::new_checked(packet).ok()?;
495 if ipv4.next_header() != IpProtocol::Tcp {
496 return None;
497 }
498 let tcp = TcpPacket::new_checked(ipv4.payload()).ok()?;
499 if !tcp.syn() || tcp.ack() {
500 return None;
501 }
502 let src = IpEndpoint::new(IpAddress::Ipv4(ipv4.src_addr()), tcp.src_port());
503 let dst = IpEndpoint::new(IpAddress::Ipv4(ipv4.dst_addr()), tcp.dst_port());
504 Some((src, dst))
505}