tun2proxy/
lib.rs

1#[cfg(feature = "udpgw")]
2use crate::udpgw::UdpGwClient;
3use crate::{
4    directions::{IncomingDataEvent, IncomingDirection, OutgoingDirection},
5    http::HttpManager,
6    no_proxy::NoProxyManager,
7    session_info::{IpProtocol, SessionInfo},
8    virtual_dns::VirtualDns,
9};
10use ipstack::{IpStackStream, IpStackTcpStream, IpStackUdpStream};
11use proxy_handler::{ProxyHandler, ProxyHandlerManager};
12use socks::SocksProxyManager;
13pub use socks5_impl::protocol::UserKey;
14#[cfg(feature = "udpgw")]
15use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
16use std::{
17    collections::VecDeque,
18    io::ErrorKind,
19    net::{IpAddr, SocketAddr},
20    sync::Arc,
21};
22use tokio::{
23    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
24    net::{TcpSocket, TcpStream, UdpSocket},
25    sync::{Mutex, mpsc::Receiver},
26};
27pub use tokio_util::sync::CancellationToken;
28use tproxy_config::is_private_ip;
29use udp_stream::UdpStream;
30#[cfg(feature = "udpgw")]
31use udpgw::{UDPGW_KEEPALIVE_TIME, UDPGW_MAX_CONNECTIONS, UdpGwClientStream, UdpGwResponse};
32
33pub use {
34    args::{ArgDns, ArgProxy, ArgVerbosity, Args, ProxyType},
35    error::{BoxError, Error, Result},
36    traffic_status::{TrafficStatus, tun2proxy_set_traffic_status_callback},
37};
38
// When the `mimalloc` feature is enabled, mimalloc replaces the system allocator as the
// process-wide global allocator.
#[cfg(feature = "mimalloc")]
#[global_allocator]
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
42
43pub use general_api::general_run_async;
44
45mod android;
46mod args;
47mod directions;
48mod dns;
49mod dump_logger;
50mod error;
51mod general_api;
52mod http;
53mod no_proxy;
54mod proxy_handler;
55mod session_info;
56pub mod socket_transfer;
57mod socks;
58mod traffic_status;
59#[cfg(feature = "udpgw")]
60pub mod udpgw;
61mod virtual_dns;
62#[doc(hidden)]
63pub mod win_svc;
64
/// Well-known DNS port. UDP sessions to this port get DNS-specific handling in `run`
/// (destination rewrite, DNS-over-TCP or virtual DNS, depending on the DNS mode).
const DNS_PORT: u16 = 53;
66
/// Transport protocol of a socket.
///
/// On Linux this also derives bincode/serde traits so it can be serialized (presumably for
/// the socket-transfer protocol in `socket_transfer` — confirm there).
/// NOTE: bincode's derive encodes enum variants by index, so reordering the variants
/// changes the serialized form.
#[allow(unused)]
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
#[cfg_attr(
    target_os = "linux",
    derive(bincode::Encode, bincode::Decode, serde::Serialize, serde::Deserialize)
)]
pub enum SocketProtocol {
    /// TCP socket.
    Tcp,
    /// UDP socket.
    Udp,
}
77
/// Address family of a socket.
///
/// Selects which per-family queue `SocketQueue` draws from, and is sent along with socket
/// requests over the socket-transfer channel. Convertible from an [`IpAddr`].
/// NOTE: on Linux bincode encodes variants by index, so do not reorder them.
#[allow(unused)]
#[derive(Hash, Copy, Clone, Eq, PartialEq, Debug)]
#[cfg_attr(
    target_os = "linux",
    derive(bincode::Encode, bincode::Decode, serde::Serialize, serde::Deserialize)
)]
pub enum SocketDomain {
    /// IPv4 (`AF_INET`-style) socket.
    IpV4,
    /// IPv6 (`AF_INET6`-style) socket.
    IpV6,
}
88
89impl From<IpAddr> for SocketDomain {
90    fn from(value: IpAddr) -> Self {
91        match value {
92            IpAddr::V4(_) => Self::IpV4,
93            IpAddr::V6(_) => Self::IpV6,
94        }
95    }
96}
97
98struct SocketQueue {
99    tcp_v4: Mutex<Receiver<TcpSocket>>,
100    tcp_v6: Mutex<Receiver<TcpSocket>>,
101    udp_v4: Mutex<Receiver<UdpSocket>>,
102    udp_v6: Mutex<Receiver<UdpSocket>>,
103}
104
105impl SocketQueue {
106    async fn recv_tcp(&self, domain: SocketDomain) -> Result<TcpSocket, std::io::Error> {
107        match domain {
108            SocketDomain::IpV4 => &self.tcp_v4,
109            SocketDomain::IpV6 => &self.tcp_v6,
110        }
111        .lock()
112        .await
113        .recv()
114        .await
115        .ok_or(ErrorKind::Other.into())
116    }
117    async fn recv_udp(&self, domain: SocketDomain) -> Result<UdpSocket, std::io::Error> {
118        match domain {
119            SocketDomain::IpV4 => &self.udp_v4,
120            SocketDomain::IpV6 => &self.udp_v6,
121        }
122        .lock()
123        .await
124        .recv()
125        .await
126        .ok_or(ErrorKind::Other.into())
127    }
128}
129
130async fn create_tcp_stream(socket_queue: &Option<Arc<SocketQueue>>, peer: SocketAddr) -> std::io::Result<TcpStream> {
131    match &socket_queue {
132        None => TcpStream::connect(peer).await,
133        Some(queue) => queue.recv_tcp(peer.ip().into()).await?.connect(peer).await,
134    }
135}
136
137async fn create_udp_stream(socket_queue: &Option<Arc<SocketQueue>>, peer: SocketAddr) -> std::io::Result<UdpStream> {
138    match &socket_queue {
139        None => UdpStream::connect(peer).await,
140        Some(queue) => {
141            let socket = queue.recv_udp(peer.ip().into()).await?;
142            socket.connect(peer).await?;
143            UdpStream::from_tokio(socket, peer).await
144        }
145    }
146}
147
148/// Run the proxy server
149/// # Arguments
150/// * `device` - The network device to use
151/// * `mtu` - The MTU of the network device
152/// * `args` - The arguments to use
153/// * `shutdown_token` - The token to exit the server
154/// # Returns
155/// * The number of sessions while exiting
156pub async fn run<D>(device: D, mtu: u16, args: Args, shutdown_token: CancellationToken) -> crate::Result<usize>
157where
158    D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
159{
160    log::info!("{} {} starting...", env!("CARGO_PKG_NAME"), version_info!());
161    log::info!("Proxy {} server: {}", args.proxy.proxy_type, args.proxy.addr);
162
163    let server_addr = args.proxy.addr;
164    let key = args.proxy.credentials.clone();
165    let dns_addr = args.dns_addr;
166    let ipv6_enabled = args.ipv6_enabled;
167    let virtual_dns = if args.dns == ArgDns::Virtual {
168        Some(Arc::new(Mutex::new(VirtualDns::new(args.virtual_dns_pool))))
169    } else {
170        None
171    };
172
173    #[cfg(target_os = "linux")]
174    let socket_queue = match args.socket_transfer_fd {
175        None => None,
176        Some(fd) => {
177            use crate::socket_transfer::{reconstruct_socket, reconstruct_transfer_socket, request_sockets};
178            use tokio::sync::mpsc::channel;
179
180            let fd = reconstruct_socket(fd)?;
181            let socket = reconstruct_transfer_socket(fd)?;
182            let socket = Arc::new(Mutex::new(socket));
183
184            macro_rules! create_socket_queue {
185                ($domain:ident) => {{
186                    const SOCKETS_PER_REQUEST: usize = 64;
187
188                    let socket = socket.clone();
189                    let (tx, rx) = channel(SOCKETS_PER_REQUEST);
190                    tokio::spawn(async move {
191                        loop {
192                            let sockets =
193                                match request_sockets(socket.lock().await, SocketDomain::$domain, SOCKETS_PER_REQUEST as u32).await {
194                                    Ok(sockets) => sockets,
195                                    Err(err) => {
196                                        log::warn!("Socket allocation request failed: {err}");
197                                        continue;
198                                    }
199                                };
200                            for s in sockets {
201                                if let Err(_) = tx.send(s).await {
202                                    return;
203                                }
204                            }
205                        }
206                    });
207                    Mutex::new(rx)
208                }};
209            }
210
211            Some(Arc::new(SocketQueue {
212                tcp_v4: create_socket_queue!(IpV4),
213                tcp_v6: create_socket_queue!(IpV6),
214                udp_v4: create_socket_queue!(IpV4),
215                udp_v6: create_socket_queue!(IpV6),
216            }))
217        }
218    };
219
220    #[cfg(not(target_os = "linux"))]
221    let socket_queue = None;
222
223    use socks5_impl::protocol::Version::{V4, V5};
224    let mgr: Arc<dyn ProxyHandlerManager> = match args.proxy.proxy_type {
225        ProxyType::Socks5 => Arc::new(SocksProxyManager::new(server_addr, V5, key)),
226        ProxyType::Socks4 => Arc::new(SocksProxyManager::new(server_addr, V4, key)),
227        ProxyType::Http => Arc::new(HttpManager::new(server_addr, key)),
228        ProxyType::None => Arc::new(NoProxyManager::new()),
229    };
230
231    let mut ipstack_config = ipstack::IpStackConfig::default();
232    ipstack_config.mtu(mtu);
233    ipstack_config.tcp_timeout(std::time::Duration::from_secs(args.tcp_timeout));
234    ipstack_config.udp_timeout(std::time::Duration::from_secs(args.udp_timeout));
235
236    let mut ip_stack = ipstack::IpStack::new(ipstack_config, device);
237
238    #[cfg(feature = "udpgw")]
239    let udpgw_client = args.udpgw_server.map(|addr| {
240        log::info!("UDP Gateway enabled, server: {}", addr);
241        use std::time::Duration;
242        let client = Arc::new(UdpGwClient::new(
243            mtu,
244            args.udpgw_connections.unwrap_or(UDPGW_MAX_CONNECTIONS),
245            args.udpgw_keepalive.map(Duration::from_secs).unwrap_or(UDPGW_KEEPALIVE_TIME),
246            args.udp_timeout,
247            addr,
248        ));
249        let client_keepalive = client.clone();
250        tokio::spawn(async move {
251            let _ = client_keepalive.heartbeat_task().await;
252        });
253        client
254    });
255
256    let task_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
257    use std::sync::atomic::Ordering::Relaxed;
258
259    loop {
260        let task_count = task_count.clone();
261        let virtual_dns = virtual_dns.clone();
262        let ip_stack_stream = tokio::select! {
263            _ = shutdown_token.cancelled() => {
264                log::info!("Shutdown received");
265                break;
266            }
267            ip_stack_stream = ip_stack.accept() => {
268                ip_stack_stream?
269            }
270        };
271        let max_sessions = args.max_sessions;
272        match ip_stack_stream {
273            IpStackStream::Tcp(tcp) => {
274                if task_count.load(Relaxed) >= max_sessions {
275                    if args.exit_on_fatal_error {
276                        log::info!("Too many sessions that over {max_sessions}, exiting...");
277                        break;
278                    }
279                    log::warn!("Too many sessions that over {max_sessions}, dropping new session");
280                    continue;
281                }
282                log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1));
283                let info = SessionInfo::new(tcp.local_addr(), tcp.peer_addr(), IpProtocol::Tcp);
284                let domain_name = if let Some(virtual_dns) = &virtual_dns {
285                    let mut virtual_dns = virtual_dns.lock().await;
286                    virtual_dns.touch_ip(&tcp.peer_addr().ip());
287                    virtual_dns.resolve_ip(&tcp.peer_addr().ip()).cloned()
288                } else {
289                    None
290                };
291                let proxy_handler = mgr.new_proxy_handler(info, domain_name, false).await?;
292                let socket_queue = socket_queue.clone();
293                tokio::spawn(async move {
294                    if let Err(err) = handle_tcp_session(tcp, proxy_handler, socket_queue).await {
295                        log::error!("{} error \"{}\"", info, err);
296                    }
297                    log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
298                });
299            }
300            IpStackStream::Udp(udp) => {
301                if task_count.load(Relaxed) >= max_sessions {
302                    if args.exit_on_fatal_error {
303                        log::info!("Too many sessions that over {max_sessions}, exiting...");
304                        break;
305                    }
306                    log::warn!("Too many sessions that over {max_sessions}, dropping new session");
307                    continue;
308                }
309                log::trace!("Session count {}", task_count.fetch_add(1, Relaxed).saturating_add(1));
310                let mut info = SessionInfo::new(udp.local_addr(), udp.peer_addr(), IpProtocol::Udp);
311                if info.dst.port() == DNS_PORT {
312                    if is_private_ip(info.dst.ip()) {
313                        info.dst.set_ip(dns_addr); // !!! Here we change the destination address to remote DNS server!!!
314                    }
315                    if args.dns == ArgDns::OverTcp {
316                        info.protocol = IpProtocol::Tcp;
317                        let proxy_handler = mgr.new_proxy_handler(info, None, false).await?;
318                        let socket_queue = socket_queue.clone();
319                        tokio::spawn(async move {
320                            if let Err(err) = handle_dns_over_tcp_session(udp, proxy_handler, socket_queue, ipv6_enabled).await {
321                                log::error!("{} error \"{}\"", info, err);
322                            }
323                            log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
324                        });
325                        continue;
326                    }
327                    if args.dns == ArgDns::Virtual {
328                        tokio::spawn(async move {
329                            if let Some(virtual_dns) = virtual_dns {
330                                if let Err(err) = handle_virtual_dns_session(udp, virtual_dns).await {
331                                    log::error!("{} error \"{}\"", info, err);
332                                }
333                            }
334                            log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
335                        });
336                        continue;
337                    }
338                    assert_eq!(args.dns, ArgDns::Direct);
339                }
340                let domain_name = if let Some(virtual_dns) = &virtual_dns {
341                    let mut virtual_dns = virtual_dns.lock().await;
342                    virtual_dns.touch_ip(&udp.peer_addr().ip());
343                    virtual_dns.resolve_ip(&udp.peer_addr().ip()).cloned()
344                } else {
345                    None
346                };
347                #[cfg(feature = "udpgw")]
348                if let Some(udpgw) = udpgw_client.clone() {
349                    let tcp_src = match udp.peer_addr() {
350                        SocketAddr::V4(_) => SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
351                        SocketAddr::V6(_) => SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)),
352                    };
353                    let tcpinfo = SessionInfo::new(tcp_src, udpgw.get_udpgw_server_addr(), IpProtocol::Tcp);
354                    let proxy_handler = mgr.new_proxy_handler(tcpinfo, None, false).await?;
355                    let queue = socket_queue.clone();
356                    tokio::spawn(async move {
357                        let dst = info.dst; // real UDP destination address
358                        let dst_addr = match domain_name {
359                            Some(ref d) => socks5_impl::protocol::Address::from((d.clone(), dst.port())),
360                            None => dst.into(),
361                        };
362                        if let Err(e) = handle_udp_gateway_session(udp, udpgw, &dst_addr, proxy_handler, queue, ipv6_enabled).await {
363                            log::info!("Ending {} with \"{}\"", info, e);
364                        }
365                        log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
366                    });
367                    continue;
368                }
369                match mgr.new_proxy_handler(info, domain_name, true).await {
370                    Ok(proxy_handler) => {
371                        let socket_queue = socket_queue.clone();
372                        tokio::spawn(async move {
373                            let ty = args.proxy.proxy_type;
374                            if let Err(err) = handle_udp_associate_session(udp, ty, proxy_handler, socket_queue, ipv6_enabled).await {
375                                log::info!("Ending {} with \"{}\"", info, err);
376                            }
377                            log::trace!("Session count {}", task_count.fetch_sub(1, Relaxed).saturating_sub(1));
378                        });
379                    }
380                    Err(e) => {
381                        log::error!("Failed to create UDP connection: {}", e);
382                    }
383                }
384            }
385            IpStackStream::UnknownTransport(u) => {
386                let len = u.payload().len();
387                log::info!("#0 unhandled transport - Ip Protocol {:?}, length {}", u.ip_protocol(), len);
388                continue;
389            }
390            IpStackStream::UnknownNetwork(pkt) => {
391                log::info!("#0 unknown transport - {} bytes", pkt.len());
392                continue;
393            }
394        }
395    }
396    Ok(task_count.load(Relaxed))
397}
398
399async fn handle_virtual_dns_session(mut udp: IpStackUdpStream, dns: Arc<Mutex<VirtualDns>>) -> crate::Result<()> {
400    let mut buf = [0_u8; 4096];
401    loop {
402        let len = match udp.read(&mut buf).await {
403            Err(e) => {
404                // indicate UDP read fails not an error.
405                log::debug!("Virtual DNS session error: {}", e);
406                break;
407            }
408            Ok(len) => len,
409        };
410        if len == 0 {
411            break;
412        }
413        let (msg, qname, ip) = dns.lock().await.generate_query(&buf[..len])?;
414        udp.write_all(&msg).await?;
415        log::debug!("Virtual DNS query: {} -> {}", qname, ip);
416    }
417    Ok(())
418}
419
420async fn copy_and_record_traffic<R, W>(reader: &mut R, writer: &mut W, is_tx: bool) -> tokio::io::Result<u64>
421where
422    R: tokio::io::AsyncRead + Unpin + ?Sized,
423    W: tokio::io::AsyncWrite + Unpin + ?Sized,
424{
425    let mut buf = vec![0; 8192];
426    let mut total = 0;
427    loop {
428        match reader.read(&mut buf).await? {
429            0 => break, // EOF
430            n => {
431                total += n as u64;
432                let (tx, rx) = if is_tx { (n, 0) } else { (0, n) };
433                if let Err(e) = crate::traffic_status::traffic_status_update(tx, rx) {
434                    log::debug!("Record traffic status error: {}", e);
435                }
436                writer.write_all(&buf[..n]).await?;
437            }
438        }
439    }
440    Ok(total)
441}
442
443async fn handle_tcp_session(
444    mut tcp_stack: IpStackTcpStream,
445    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
446    socket_queue: Option<Arc<SocketQueue>>,
447) -> crate::Result<()> {
448    let (session_info, server_addr) = {
449        let handler = proxy_handler.lock().await;
450
451        (handler.get_session_info(), handler.get_server_addr())
452    };
453
454    let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
455
456    log::info!("Beginning {}", session_info);
457
458    if let Err(e) = handle_proxy_session(&mut server, proxy_handler).await {
459        tcp_stack.shutdown().await?;
460        return Err(e);
461    }
462
463    let (mut t_rx, mut t_tx) = tokio::io::split(tcp_stack);
464    let (mut s_rx, mut s_tx) = tokio::io::split(server);
465
466    let res = tokio::join!(
467        async move {
468            let r = copy_and_record_traffic(&mut t_rx, &mut s_tx, true).await;
469            if let Err(err) = s_tx.shutdown().await {
470                log::trace!("{} s_tx shutdown error {}", session_info, err);
471            }
472            r
473        },
474        async move {
475            let r = copy_and_record_traffic(&mut s_rx, &mut t_tx, false).await;
476            if let Err(err) = t_tx.shutdown().await {
477                log::trace!("{} t_tx shutdown error {}", session_info, err);
478            }
479            r
480        },
481    );
482    log::info!("Ending {} with {:?}", session_info, res);
483
484    Ok(())
485}
486
/// Relays one UDP session through a UDP gateway (udpgw) server reached via the proxy.
///
/// A pooled gateway connection is taken from `udpgw_client` when one is available and not
/// closed. Otherwise a new TCP connection to the proxy server is opened (using a queued
/// socket when `socket_queue` is set) and the proxy handshake is run with `proxy_handler`.
/// Datagrams read from `udp_stack` are sent to the gateway addressed to `udp_dst`, and
/// gateway data responses are written back to `udp_stack`, until one side ends the session.
///
/// If the gateway stream has not been marked closed when the relay loop exits, it is
/// handed back to `udpgw_client` for reuse.
///
/// # Errors
/// Returns an error if a new proxy connection cannot be opened or its handshake fails, if
/// the stream's reader or writer cannot be obtained, or if traffic accounting fails. Most
/// relay failures are only logged and end the session with `Ok(())`.
/// NOTE(review): a traffic-accounting error returns early via `?`, so the stream is then
/// neither returned to the pool nor explicitly closed.
#[cfg(feature = "udpgw")]
async fn handle_udp_gateway_session(
    mut udp_stack: IpStackUdpStream,
    udpgw_client: Arc<UdpGwClient>,
    udp_dst: &socks5_impl::protocol::Address,
    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
    socket_queue: Option<Arc<SocketQueue>>,
    ipv6_enabled: bool,
) -> crate::Result<()> {
    let proxy_server_addr = { proxy_handler.lock().await.get_server_addr() };
    let udp_mtu = udpgw_client.get_udp_mtu();
    let udp_timeout = udpgw_client.get_udp_timeout();

    // Prefer a pooled connection, skipping any that have closed since they were stored.
    // Only when the pool is empty is a new proxied TCP connection opened.
    let mut stream = loop {
        match udpgw_client.pop_server_connection_from_queue().await {
            Some(stream) => {
                if stream.is_closed() {
                    continue;
                } else {
                    break stream;
                }
            }
            None => {
                let mut tcp_server_stream = create_tcp_stream(&socket_queue, proxy_server_addr).await?;
                if let Err(e) = handle_proxy_session(&mut tcp_server_stream, proxy_handler).await {
                    return Err(format!("udpgw connection error: {}", e).into());
                }
                break UdpGwClientStream::new(tcp_server_stream);
            }
        }
    };

    let tcp_local_addr = stream.local_addr();
    let sn = stream.serial_number();

    log::info!("[UdpGw] Beginning stream {} {} -> {}", sn, &tcp_local_addr, udp_dst);

    // Take the read and write halves out of the stream; they are given back to the pool
    // together with the stream at the end via `store_server_connection_full`.
    let Some(mut reader) = stream.get_reader() else {
        return Err("get reader failed".into());
    };

    let Some(mut writer) = stream.get_writer() else {
        return Err("get writer failed".into());
    };

    // Datagram buffer sized to the gateway client's UDP MTU.
    let mut tmp_buf = vec![0; udp_mtu.into()];

    loop {
        tokio::select! {
            // Client -> gateway.
            len = udp_stack.read(&mut tmp_buf) => {
                let read_len = match len {
                    Ok(0) => {
                        log::info!("[UdpGw] Ending stream {} {} <> {}", sn, &tcp_local_addr, udp_dst);
                        break;
                    }
                    Ok(n) => n,
                    Err(e) => {
                        log::info!("[UdpGw] Ending stream {} {} <> {} with udp stack \"{}\"", sn, &tcp_local_addr, udp_dst, e);
                        break;
                    }
                };
                crate::traffic_status::traffic_status_update(read_len, 0)?;
                let sn = stream.serial_number();
                if let Err(e) = UdpGwClient::send_udpgw_packet(ipv6_enabled, &tmp_buf[0..read_len], udp_dst, sn, &mut writer).await {
                    log::info!("[UdpGw] Ending stream {} {} <> {} with send_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
                    break;
                }
                log::debug!("[UdpGw] stream {} {} -> {} send len {}", sn, &tcp_local_addr, udp_dst, read_len);
                stream.update_activity();
            }
            // Gateway -> client. Waits at most `udp_timeout` (as passed to recv_udpgw_packet).
            ret = UdpGwClient::recv_udpgw_packet(udp_mtu, udp_timeout, &mut reader) => {
                if let Ok((len, _)) = ret {
                    crate::traffic_status::traffic_status_update(0, len)?;
                }
                match ret {
                    Err(e) => {
                        log::warn!("[UdpGw] Ending stream {} {} <> {} with recv_udpgw_packet {}", sn, &tcp_local_addr, udp_dst, e);
                        stream.close();
                        break;
                    }
                    Ok((_, packet)) => match packet {
                        // A keepalive is not expected on an active session: close the stream
                        // so it is not returned to the pool.
                        UdpGwResponse::KeepAlive => {
                            log::error!("[UdpGw] Ending stream {} {} <> {} with recv keepalive", sn, &tcp_local_addr, udp_dst);
                            stream.close();
                            break;
                        }
                        // The server-side UDP socket may have timed out. Despite the "Ending"
                        // wording of the log, the session keeps running.
                        // NOTE(review): open question from the original author: can UDP data
                        // still be received after this response?
                        UdpGwResponse::Error => {
                            log::info!("[UdpGw] Ending stream {} {} <> {} with recv udp error", sn, &tcp_local_addr, udp_dst);
                            stream.update_activity();
                            continue;
                        }
                        UdpGwResponse::TcpClose => {
                            log::error!("[UdpGw] Ending stream {} {} <> {} with tcp closed", sn, &tcp_local_addr, udp_dst);
                            stream.close();
                            break;
                        }
                        UdpGwResponse::Data(data) => {
                            // Presumably needed for the `len()` call below (trait method) — confirm.
                            use socks5_impl::protocol::StreamOperation;
                            let len = data.len();
                            let f = data.header.flags;
                            log::debug!("[UdpGw] stream {sn} {} <- {} receive {f} len {len}", &tcp_local_addr, udp_dst);
                            // A failure here is on the client side; the gateway stream is not
                            // closed and may still be returned to the pool below.
                            if let Err(e) = udp_stack.write_all(&data.data).await {
                                log::error!("[UdpGw] Ending stream {} {} <> {} with send_udp_packet {}", sn, &tcp_local_addr, udp_dst, e);
                                break;
                            }
                        }
                    }
                }
                stream.update_activity();
            }
        }
    }

    // Return still-open gateway connections to the pool for reuse by later sessions.
    if !stream.is_closed() {
        udpgw_client.store_server_connection_full(stream, reader, writer).await;
    }

    Ok(())
}
608
609async fn handle_udp_associate_session(
610    mut udp_stack: IpStackUdpStream,
611    proxy_type: ProxyType,
612    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
613    socket_queue: Option<Arc<SocketQueue>>,
614    ipv6_enabled: bool,
615) -> crate::Result<()> {
616    use socks5_impl::protocol::{Address, StreamOperation, UdpHeader};
617
618    let (session_info, server_addr, domain_name, udp_addr) = {
619        let handler = proxy_handler.lock().await;
620        (
621            handler.get_session_info(),
622            handler.get_server_addr(),
623            handler.get_domain_name(),
624            handler.get_udp_associate(),
625        )
626    };
627
628    log::info!("Beginning {}", session_info);
629
630    // `_server` is meaningful here, it must be alive all the time
631    // to ensure that UDP transmission will not be interrupted accidentally.
632    let (_server, udp_addr) = match udp_addr {
633        Some(udp_addr) => (None, udp_addr),
634        None => {
635            let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
636            let udp_addr = handle_proxy_session(&mut server, proxy_handler).await?;
637            (Some(server), udp_addr.ok_or("udp associate failed")?)
638        }
639    };
640
641    let mut udp_server = create_udp_stream(&socket_queue, udp_addr).await?;
642
643    let mut buf1 = [0_u8; 4096];
644    let mut buf2 = [0_u8; 4096];
645    loop {
646        tokio::select! {
647            len = udp_stack.read(&mut buf1) => {
648                let len = len?;
649                if len == 0 {
650                    break;
651                }
652                let buf1 = &buf1[..len];
653
654                crate::traffic_status::traffic_status_update(len, 0)?;
655
656                if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type {
657                    let s5addr = if let Some(domain_name) = &domain_name {
658                        Address::DomainAddress(domain_name.clone(), session_info.dst.port())
659                    } else {
660                        session_info.dst.into()
661                    };
662
663                    // Add SOCKS5 UDP header to the incoming data
664                    let mut s5_udp_data = Vec::<u8>::new();
665                    UdpHeader::new(0, s5addr).write_to_stream(&mut s5_udp_data)?;
666                    s5_udp_data.extend_from_slice(buf1);
667
668                    udp_server.write_all(&s5_udp_data).await?;
669                } else {
670                    udp_server.write_all(buf1).await?;
671                }
672            }
673            len = udp_server.read(&mut buf2) => {
674                let len = len?;
675                if len == 0 {
676                    break;
677                }
678                let buf2 = &buf2[..len];
679
680                crate::traffic_status::traffic_status_update(0, len)?;
681
682                if let ProxyType::Socks4 | ProxyType::Socks5 = proxy_type {
683                    // Remove SOCKS5 UDP header from the server data
684                    let header = UdpHeader::retrieve_from_stream(&mut &buf2[..])?;
685                    let data = &buf2[header.len()..];
686
687                    let buf = if session_info.dst.port() == DNS_PORT {
688                        let mut message = dns::parse_data_to_dns_message(data, false)?;
689                        if !ipv6_enabled {
690                            dns::remove_ipv6_entries(&mut message);
691                        }
692                        message.to_vec()?
693                    } else {
694                        data.to_vec()
695                    };
696
697                    udp_stack.write_all(&buf).await?;
698                } else {
699                    udp_stack.write_all(buf2).await?;
700                }
701            }
702        }
703    }
704
705    log::info!("Ending {}", session_info);
706
707    Ok(())
708}
709
710async fn handle_dns_over_tcp_session(
711    mut udp_stack: IpStackUdpStream,
712    proxy_handler: Arc<Mutex<dyn ProxyHandler>>,
713    socket_queue: Option<Arc<SocketQueue>>,
714    ipv6_enabled: bool,
715) -> crate::Result<()> {
716    let (session_info, server_addr) = {
717        let handler = proxy_handler.lock().await;
718
719        (handler.get_session_info(), handler.get_server_addr())
720    };
721
722    let mut server = create_tcp_stream(&socket_queue, server_addr).await?;
723
724    log::info!("Beginning {}", session_info);
725
726    let _ = handle_proxy_session(&mut server, proxy_handler).await?;
727
728    let mut buf1 = [0_u8; 4096];
729    let mut buf2 = [0_u8; 4096];
730    loop {
731        tokio::select! {
732            len = udp_stack.read(&mut buf1) => {
733                let len = len?;
734                if len == 0 {
735                    break;
736                }
737                let buf1 = &buf1[..len];
738
739                _ = dns::parse_data_to_dns_message(buf1, false)?;
740
741                // Insert the DNS message length in front of the payload
742                let len = u16::try_from(buf1.len())?;
743                let mut buf = Vec::with_capacity(std::mem::size_of::<u16>() + usize::from(len));
744                buf.extend_from_slice(&len.to_be_bytes());
745                buf.extend_from_slice(buf1);
746
747                server.write_all(&buf).await?;
748
749                crate::traffic_status::traffic_status_update(buf.len(), 0)?;
750            }
751            len = server.read(&mut buf2) => {
752                let len = len?;
753                if len == 0 {
754                    break;
755                }
756                let mut buf = buf2[..len].to_vec();
757
758                crate::traffic_status::traffic_status_update(0, len)?;
759
760                let mut to_send: VecDeque<Vec<u8>> = VecDeque::new();
761                loop {
762                    if buf.len() < 2 {
763                        break;
764                    }
765                    let len = u16::from_be_bytes([buf[0], buf[1]]) as usize;
766                    if buf.len() < len + 2 {
767                        break;
768                    }
769
770                    // remove the length field
771                    let data = buf[2..len + 2].to_vec();
772
773                    let mut message = dns::parse_data_to_dns_message(&data, false)?;
774
775                    let name = dns::extract_domain_from_dns_message(&message)?;
776                    let ip = dns::extract_ipaddr_from_dns_message(&message);
777                    log::trace!("DNS over TCP query result: {} -> {:?}", name, ip);
778
779                    if !ipv6_enabled {
780                        dns::remove_ipv6_entries(&mut message);
781                    }
782
783                    to_send.push_back(message.to_vec()?);
784                    if len + 2 == buf.len() {
785                        break;
786                    }
787                    buf = buf[len + 2..].to_vec();
788                }
789
790                while let Some(packet) = to_send.pop_front() {
791                    udp_stack.write_all(&packet).await?;
792                }
793            }
794        }
795    }
796
797    log::info!("Ending {}", session_info);
798
799    Ok(())
800}
801
802/// This function is used to handle the business logic of tun2proxy and SOCKS5 server.
803/// When handling UDP proxy, the return value UDP associate IP address is the result of this business logic.
804/// However, when handling TCP business logic, the return value Ok(None) is meaningless, just indicating that the operation was successful.
805async fn handle_proxy_session(server: &mut TcpStream, proxy_handler: Arc<Mutex<dyn ProxyHandler>>) -> crate::Result<Option<SocketAddr>> {
806    let mut launched = false;
807    let mut proxy_handler = proxy_handler.lock().await;
808    let dir = OutgoingDirection::ToServer;
809    let (mut tx, mut rx) = (0, 0);
810
811    loop {
812        if proxy_handler.connection_established() {
813            break;
814        }
815
816        if !launched {
817            let data = proxy_handler.peek_data(dir).buffer;
818            let len = data.len();
819            if len == 0 {
820                return Err("proxy_handler launched went wrong".into());
821            }
822            server.write_all(data).await?;
823            proxy_handler.consume_data(dir, len);
824            tx += len;
825
826            launched = true;
827        }
828
829        let mut buf = [0_u8; 4096];
830        let len = server.read(&mut buf).await?;
831        if len == 0 {
832            return Err("server closed accidentially".into());
833        }
834        rx += len;
835        let event = IncomingDataEvent {
836            direction: IncomingDirection::FromServer,
837            buffer: &buf[..len],
838        };
839        proxy_handler.push_data(event).await?;
840
841        let data = proxy_handler.peek_data(dir).buffer;
842        let len = data.len();
843        if len > 0 {
844            server.write_all(data).await?;
845            proxy_handler.consume_data(dir, len);
846            tx += len;
847        }
848    }
849    crate::traffic_status::traffic_status_update(tx, rx)?;
850    Ok(proxy_handler.get_udp_associate())
851}