1use std::{
21 collections::HashMap,
22 fs::File,
23 io::{Read, Write},
24 net::{IpAddr, Ipv4Addr},
25 os::unix::io::AsRawFd,
26 sync::Arc,
27 thread,
28};
29
30use anyhow::{Context, Result};
31use rand::{rngs::OsRng, RngCore};
32use smoltcp::{
33 iface::{Config, Interface, SocketHandle, SocketSet},
34 phy::{wait as phy_wait, Device, DeviceCapabilities, Medium, RxToken, TxToken},
35 socket::tcp::{self as tcp, Socket as TcpSocket},
36 time::{Duration as SmolDuration, Instant as SmolInstant},
37 wire::{
38 HardwareAddress, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, IpProtocol, Ipv4Address,
39 Ipv4Packet, TcpPacket, UdpPacket,
40 },
41};
42use tokio::{
43 io::{AsyncReadExt, AsyncWriteExt},
44 runtime::Handle,
45 sync::mpsc,
46};
47use tracing::{debug, warn};
48
49use crate::{
50 chain::ChainEngine,
51 dns::DnsMap,
52 proxy::{BoxStream, Target},
53};
54
55pub use crate::chain::ChainConfig;
56
/// In-memory smoltcp device that shuttles raw IP packets between the TUN file
/// and the interface.
///
/// The poll loop stores the packet it just read in `rx` and, after polling the
/// interface, drains `tx` back to the TUN file.
struct IpDevice {
    /// Inbound packet waiting to be handed to smoltcp (at most one at a time).
    rx: Option<Vec<u8>>,
    /// Outbound packets produced by smoltcp and not yet written out.
    tx: Vec<Vec<u8>>,
    /// Value reported as the device's maximum transmission unit.
    mtu: usize,
}

impl IpDevice {
    /// Creates a device with empty queues that advertises `mtu`.
    fn new(mtu: usize) -> Self {
        let rx = None;
        let tx = Vec::new();
        IpDevice { rx, tx, mtu }
    }
}
79
80struct OwnedRxToken(Vec<u8>);
81impl RxToken for OwnedRxToken {
82 fn consume<R, F: FnOnce(&mut [u8]) -> R>(mut self, f: F) -> R {
83 f(&mut self.0)
84 }
85}
86
87struct OwnedTxToken<'a>(&'a mut Vec<Vec<u8>>);
88impl<'a> TxToken for OwnedTxToken<'a> {
89 fn consume<R, F: FnOnce(&mut [u8]) -> R>(self, len: usize, f: F) -> R {
90 let mut buf = vec![0u8; len];
91 let r = f(&mut buf);
92 self.0.push(buf);
93 r
94 }
95}
96
97impl Device for IpDevice {
98 type RxToken<'a> = OwnedRxToken;
99 type TxToken<'a> = OwnedTxToken<'a>;
100
101 fn receive(&mut self, _: SmolInstant) -> Option<(OwnedRxToken, OwnedTxToken<'_>)> {
102 self.rx
103 .take()
104 .map(|p| (OwnedRxToken(p), OwnedTxToken(&mut self.tx)))
105 }
106
107 fn transmit(&mut self, _: SmolInstant) -> Option<OwnedTxToken<'_>> {
108 Some(OwnedTxToken(&mut self.tx))
109 }
110
111 fn capabilities(&self) -> DeviceCapabilities {
112 let mut caps = DeviceCapabilities::default();
113 caps.medium = Medium::Ip;
114 caps.max_transmission_unit = self.mtu;
115 caps
116 }
117}
118
119#[derive(Debug, Clone)]
123pub struct TunnelConfig {
124 pub chain: ChainConfig,
126 pub dns_map: DnsMap,
128 pub tun_ip: Ipv4Addr,
130 pub prefix_len: u8,
132 pub dns_ip: Option<Ipv4Addr>,
135}
136
137struct FlowInfo {
140 to_relay: mpsc::Sender<Vec<u8>>,
142 from_relay: mpsc::Receiver<Vec<u8>>,
144}
145
146pub struct ProxyChainTunnel {
151 config: TunnelConfig,
152}
153
154impl ProxyChainTunnel {
155 pub fn new(config: TunnelConfig) -> Self {
157 Self { config }
158 }
159
160 pub fn spawn(self, tun_file: File, rt: Handle) -> thread::JoinHandle<Result<()>> {
165 thread::spawn(move || poll_loop(self.config, tun_file, rt))
166 }
167
168 pub async fn run(self, tun_file: File) -> Result<()> {
178 let rt = Handle::current();
179 let config = self.config;
180 let join = thread::spawn(move || poll_loop(config, tun_file, rt));
183 tokio::task::spawn_blocking(move || join.join())
184 .await
185 .context("tunnel thread panicked")?
186 .map_err(|_| anyhow::anyhow!("tunnel thread panicked"))
187 .and_then(|r| r)
188 }
189}
190
/// Number of chunks each per-flow mpsc channel can hold (both directions).
const CHANNEL_CAPACITY: usize = 64;
/// Size in bytes of each TCP socket's receive buffer and of its send buffer.
const SOCKET_BUF: usize = 64 * 1024;
/// Upper bound, in milliseconds, on how long the poll loop waits for TUN input.
const MAX_POLL_WAIT_MS: u64 = 5;
199
200fn poll_loop(config: TunnelConfig, mut tun_file: File, rt: Handle) -> Result<()> {
201 let mtu = 1500usize;
202
203 let tun_raw_fd = tun_file.as_raw_fd();
206
207 let mut device = IpDevice::new(mtu);
209
210 let mut iface_cfg = Config::new(HardwareAddress::Ip);
211 iface_cfg.random_seed = OsRng.next_u64();
213 let mut iface = Interface::new(iface_cfg, &mut device, SmolInstant::now());
214 iface.set_any_ip(true);
216 iface.update_ip_addrs(|addrs| {
217 let cidr = IpCidr::new(IpAddress::Ipv4(config.tun_ip.into()), config.prefix_len);
218 let _ = addrs.push(cidr);
219 });
220 iface
224 .routes_mut()
225 .add_default_ipv4_route(Ipv4Address::from(config.tun_ip))
226 .expect("smoltcp route table is full");
227
228 let mut sockets = SocketSet::new(vec![]);
229 let engine = Arc::new(ChainEngine::new(config.chain));
230 let dns_map = config.dns_map;
231 let dns_ip = config.dns_ip;
232
233 let mut pending: HashMap<(IpEndpoint, IpEndpoint), SocketHandle> = HashMap::new();
236
237 let mut active: HashMap<SocketHandle, FlowInfo> = HashMap::new();
239
240 let mut read_buf = vec![0u8; mtu + 4];
241
242 loop {
243 let delay = iface
245 .poll_delay(SmolInstant::now(), &sockets)
246 .map(|d| d.min(SmolDuration::from_millis(MAX_POLL_WAIT_MS)));
247 phy_wait(tun_raw_fd, delay).ok();
248
249 let maybe_packet = match tun_file.read(&mut read_buf) {
253 Ok(n) if n > 0 => Some(read_buf[..n].to_vec()),
254 _ => None,
255 };
256
257 if let Some(pkt) = maybe_packet {
258 if let Some(dns_ip) = dns_ip {
260 if let Some(resp) = try_handle_dns(&pkt, dns_ip, &dns_map) {
261 let _ = tun_file.write_all(&resp);
262 continue;
264 }
265 }
266
267 if let Some((src, dst)) = extract_tcp_syn(&pkt) {
270 if let std::collections::hash_map::Entry::Vacant(e) = pending.entry((src, dst)) {
271 let listen_dst = e.key().1;
273 let mut sock = TcpSocket::new(
274 tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF]),
275 tcp::SocketBuffer::new(vec![0u8; SOCKET_BUF]),
276 );
277 let listen_ep = IpListenEndpoint {
281 addr: Some(listen_dst.addr),
282 port: listen_dst.port,
283 };
284 if sock.listen(listen_ep).is_ok() {
285 let handle = sockets.add(sock);
286 e.insert(handle);
287 debug!("new TCP flow {src} → {listen_dst}");
288 }
289 }
290 }
291 device.rx = Some(pkt);
293 }
294
295 iface.poll(SmolInstant::now(), &mut device, &mut sockets);
297
298 for pkt in device.tx.drain(..) {
300 let _ = tun_file.write_all(&pkt);
301 }
302
303 pending.retain(|(_src, dst), &mut handle| {
305 let sock = sockets.get_mut::<TcpSocket>(handle);
306 if !sock.may_send() {
310 return true;
311 }
312
313 let target = {
315 let ip: IpAddr = match dst.addr {
316 IpAddress::Ipv4(a) => IpAddr::V4(Ipv4Addr::from(a)),
317 IpAddress::Ipv6(a) => IpAddr::V6(a.into()),
318 };
319 let hostname = if let IpAddr::V4(v4) = ip {
320 dns_map.lookup_hostname(v4)
321 } else {
322 None
323 };
324 match hostname {
325 Some(h) => Target::Host(h, dst.port),
326 None => Target::Ip(ip, dst.port),
327 }
328 };
329
330 let (to_relay_tx, to_relay_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
332 let (from_relay_tx, from_relay_rx) = mpsc::channel::<Vec<u8>>(CHANNEL_CAPACITY);
333
334 let eng = engine.clone();
335 rt.spawn(relay_flow(eng, target, to_relay_rx, from_relay_tx));
336
337 active.insert(
338 handle,
339 FlowInfo {
340 to_relay: to_relay_tx,
341 from_relay: from_relay_rx,
342 },
343 );
344 false });
346
347 active.retain(|&handle, info| {
349 let sock = sockets.get_mut::<TcpSocket>(handle);
350
351 if sock.can_recv() {
353 if let Ok(data) = sock.recv(|b| (b.len(), b.to_vec())) {
354 if !data.is_empty() {
355 let _ = info.to_relay.try_send(data);
356 }
357 }
358 }
359
360 while let Ok(data) = info.from_relay.try_recv() {
362 let _ = sock.send_slice(&data);
363 }
364
365 if matches!(
367 sock.state(),
368 tcp::State::Closed | tcp::State::CloseWait | tcp::State::TimeWait
369 ) && !sock.can_recv()
370 {
371 sockets.remove(handle);
372 return false;
373 }
374 true
375 });
376 }
377}
378
379async fn relay_flow(
384 engine: Arc<ChainEngine>,
385 target: Target,
386 mut smol_rx: mpsc::Receiver<Vec<u8>>,
387 smol_tx: mpsc::Sender<Vec<u8>>,
388) {
389 let proxy: BoxStream = match engine.connect(target.clone()).await {
390 Ok(s) => s,
391 Err(e) => {
392 warn!("relay: failed to connect to {target}: {e:#}");
393 return;
394 }
395 };
396
397 let (mut proxy_r, mut proxy_w) = tokio::io::split(proxy);
398 let mut buf = vec![0u8; 8192];
399
400 loop {
401 tokio::select! {
402 data = smol_rx.recv() => {
404 match data {
405 Some(d) => { if proxy_w.write_all(&d).await.is_err() { break; } }
406 None => break,
407 }
408 }
409 n = proxy_r.read(&mut buf) => {
411 match n {
412 Ok(0) | Err(_) => break,
413 Ok(n) => {
414 if smol_tx.send(buf[..n].to_vec()).await.is_err() {
415 break;
416 }
417 }
418 }
419 }
420 }
421 }
422 debug!("relay flow ended for {target}");
423}
424
425fn try_handle_dns(packet: &[u8], dns_ip: Ipv4Addr, dns_map: &DnsMap) -> Option<Vec<u8>> {
430 let ipv4 = Ipv4Packet::new_checked(packet).ok()?;
431 if ipv4.next_header() != IpProtocol::Udp {
432 return None;
433 }
434 if Ipv4Addr::from(ipv4.dst_addr()) != dns_ip {
435 return None;
436 }
437 let udp = UdpPacket::new_checked(ipv4.payload()).ok()?;
438 if udp.dst_port() != 53 {
439 return None;
440 }
441 let dns_resp = dns_map.handle_dns_query(udp.payload())?;
443
444 let udp_len = 8u16 + dns_resp.len() as u16;
446 let total_len = 20u16 + udp_len;
447 let mut pkt = vec![0u8; total_len as usize];
448
449 pkt[0] = 0x45; pkt[2] = (total_len >> 8) as u8;
452 pkt[3] = total_len as u8;
453 pkt[8] = 64; pkt[9] = 17; pkt[12..16].copy_from_slice(&ipv4.dst_addr().0); pkt[16..20].copy_from_slice(&ipv4.src_addr().0); let csum = ip_checksum(&pkt[0..20]);
458 pkt[10] = (csum >> 8) as u8;
459 pkt[11] = csum as u8;
460
461 pkt[20] = 0;
463 pkt[21] = 53; pkt[22] = (udp.src_port() >> 8) as u8;
465 pkt[23] = udp.src_port() as u8; pkt[24] = (udp_len >> 8) as u8;
467 pkt[25] = udp_len as u8;
468 pkt[28..].copy_from_slice(&dns_resp);
470
471 Some(pkt)
472}
473
/// Computes the Internet checksum (RFC 1071) of `data`.
///
/// The bytes are summed as big-endian 16-bit words, an odd trailing byte is
/// treated as the high byte of a final word, carries are folded back into the
/// low 16 bits, and the one's complement of the result is returned. For a
/// header, the checksum field must be zero while this is computed.
///
/// The sum is accumulated in 64 bits so inputs of any practical length cannot
/// overflow the accumulator.
fn ip_checksum(data: &[u8]) -> u16 {
    let mut words = data.chunks_exact(2);
    let mut sum: u64 = words
        .by_ref()
        .map(|w| u64::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = words.remainder() {
        sum += u64::from(*last) << 8;
    }
    // Fold the carries back in (end-around carry).
    while sum >> 16 != 0 {
        sum = (sum & 0xFFFF) + (sum >> 16);
    }
    !(sum as u16)
}
489
490fn extract_tcp_syn(packet: &[u8]) -> Option<(IpEndpoint, IpEndpoint)> {
492 let ipv4 = Ipv4Packet::new_checked(packet).ok()?;
493 if ipv4.next_header() != IpProtocol::Tcp {
494 return None;
495 }
496 let tcp = TcpPacket::new_checked(ipv4.payload()).ok()?;
497 if !tcp.syn() || tcp.ack() {
498 return None;
499 }
500 let src = IpEndpoint::new(IpAddress::Ipv4(ipv4.src_addr()), tcp.src_port());
501 let dst = IpEndpoint::new(IpAddress::Ipv4(ipv4.dst_addr()), tcp.dst_port());
502 Some((src, dst))
503}