port_mapping/udp_proxy.rs

1use dashmap::DashMap;
2use std::{fmt::Display, sync::Arc, time::Duration};
3use tokio::{
4    net::UdpSocket,
5    select,
6    sync::mpsc::{self, Sender},
7};
8
/// Settings for one UDP port mapping: datagrams arriving on `listen` are relayed to
/// `upstream`, and upstream replies are sent back to the downstream sender.
#[derive(Debug)]
pub struct UdpProxy {
    /// Local address the proxy binds (passed to `UdpSocket::bind`).
    pub listen: String,
    /// Address each session's upstream socket connects to (passed to `UdpSocket::connect`).
    pub upstream: String,
    /// Size in bytes of each receive buffer.
    pub buffer_size: usize,
}
15
16impl Display for UdpProxy {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        write!(f, "{}->{}", self.listen, self.upstream)
19    }
20}
21
22impl UdpProxy {
23    pub fn new(listen: String, upstream: String, buffer_size: usize) -> Self {
24        Self {
25            listen,
26            upstream,
27            buffer_size,
28        }
29    }
30
31    pub async fn run(self: Arc<Self>) -> std::io::Result<()> {
32        let server = Arc::new(UdpSocket::bind(&self.listen).await?);
33        println!("[info][udp][{self}] Listening");
34        let map: Arc<DashMap<_, Sender<Vec<u8>>>> = Arc::new(DashMap::new());
35        let mut buf = Vec::with_capacity(self.buffer_size);
36        unsafe {
37            buf.set_len(self.buffer_size);
38        }
39        loop {
40            let (len, addr) = match server.recv_from(&mut buf).await {
41                Ok(res) => res,
42                Err(e) => {
43                    eprintln!("[warning][udp][{self}] Failed to recv from downstream: {e}");
44                    continue;
45                }
46            };
47            match map.get(&addr) {
48                Some(tx) => {
49                    if let Err(e) = tx.send(buf[..len].to_vec()).await {
50                        eprintln!("[warning][udp][{self}] Tokio channel error: {e}");
51                        continue;
52                    }
53                }
54                None => {
55                    let (tx, mut rx) = mpsc::channel(100);
56                    let self_clone = self.clone();
57                    let server_clone = server.clone();
58                    if let Err(e) = tx.send(buf[..len].to_vec()).await {
59                        eprintln!("[warning][udp][{self}] Tokio channel error: {e}");
60                    }
61                    map.insert(addr, tx);
62                    let map_clone = map.clone();
63                    tokio::spawn(async move {
64                        let client = match UdpSocket::bind("127.0.0.1:0").await {
65                            Ok(client) => client,
66                            Err(e) => {
67                                eprintln!(
68                                    "[warning][udp][{self_clone}] Failed to bind client socket: {e}"
69                                );
70                                return;
71                            }
72                        };
73                        if let Err(e) = client.connect(&self_clone.upstream).await {
74                            eprintln!(
75                                "[warning][udp][{self_clone}] Failed to connect to upstream: {e}"
76                            );
77                            return;
78                        };
79                        let mut buf = Vec::with_capacity(self_clone.buffer_size);
80                        unsafe {
81                            buf.set_len(self_clone.buffer_size);
82                        }
83                        loop {
84                            select! {
85                                Some(received) = rx.recv() => {
86                                    if let Err(e) = client.send(&received).await {
87                                        eprintln!(
88                                            "[warning][udp][{self_clone}] Failed to send to upstream: {e}"
89                                        );
90                                    }
91                                }
92                                Ok(len) = client.recv(&mut buf) => {
93                                    if let Err(e) = server_clone.send_to(&buf[..len], &addr).await {
94                                        eprintln!(
95                                            "[warning][udp][{self_clone}] Failed to send to downstream: {e}"
96                                        );
97                                    }
98                                }
99                                _ = tokio::time::sleep(Duration::from_secs(60)) => {
100                                    println!(
101                                        "[info][udp][{self_clone}] No data transport for 60 seconds, closing connection"
102                                    );
103                                    break;
104                                }
105                            }
106                        }
107                        map_clone.remove(&addr);
108                    });
109                }
110            }
111        }
112    }
113}