1use std::net::{SocketAddr, IpAddr};
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4use tokio::sync::RwLock;
5use tokio::sync::mpsc::Sender;
6use uuid::Uuid;
7use tokio::net::UdpSocket;
8use std::sync::Arc;
9#[allow(unused_imports)]
10use relay_core_api::flow::{Flow, FlowUpdate, NetworkInfo, TransportProtocol, Layer, UdpLayer};
11#[allow(unused_imports)]
12use chrono::Utc;
13
14#[cfg(target_os = "linux")]
15use crate::capture::linux_tproxy::LinuxTproxy;
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18
/// Identifies a UDP session by its client/server address 4-tuple.
///
/// Every distinct `(client ip, client port, server ip, server port)` combination
/// maps to its own session. All fields are `Copy`, so the key is `Copy` as well
/// and can be used as a map key without explicit clones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct UdpSessionKey {
    /// IP address of the client (the datagram's source).
    pub src_ip: IpAddr,
    /// Source port of the client.
    pub src_port: u16,
    /// Original destination (server) IP address.
    pub dst_ip: IpAddr,
    /// Original destination (server) port.
    pub dst_port: u16,
}

impl UdpSessionKey {
    /// Builds a key from the client address `src` and the original destination `dst`.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        Self {
            src_ip: src.ip(),
            src_port: src.port(),
            dst_ip: dst.ip(),
            dst_port: dst.port(),
        }
    }
}
39
40#[derive(Debug, Clone)]
42pub struct UdpSession {
43 pub flow_id: Uuid,
44 pub key: UdpSessionKey,
45 pub created_at: Instant,
46 pub last_activity: Arc<RwLock<Instant>>,
47 pub packet_count: Arc<AtomicUsize>,
48 pub bytes_transferred: Arc<AtomicUsize>,
49 #[cfg(target_os = "linux")]
50 pub upstream_socket: Option<Arc<UdpSocket>>, #[cfg(target_os = "linux")]
52 pub downstream_socket: Option<Arc<UdpSocket>>, }
54
55pub struct UdpSessionManager {
57 sessions: RwLock<HashMap<UdpSessionKey, UdpSession>>,
58 idle_timeout: Duration,
59}
60
61impl UdpSessionManager {
62 pub fn new(idle_timeout: Duration) -> Self {
63 Self {
64 sessions: RwLock::new(HashMap::new()),
65 idle_timeout,
66 }
67 }
68
69 pub async fn get_or_create_session(&self, src: SocketAddr, dst: SocketAddr) -> std::io::Result<(UdpSession, bool)> {
72 let key = UdpSessionKey::new(src, dst);
73 {
75 let sessions = self.sessions.read().await;
76 if let Some(session) = sessions.get(&key) {
77 let mut last = session.last_activity.write().await;
78 *last = Instant::now();
79 session.packet_count.fetch_add(1, Ordering::Relaxed);
80 return Ok((session.clone(), false));
81 }
82 }
83
84 let mut sessions = self.sessions.write().await;
86 if let Some(session) = sessions.get(&key) {
88 let mut last = session.last_activity.write().await;
89 *last = Instant::now();
90 session.packet_count.fetch_add(1, Ordering::Relaxed);
91 return Ok((session.clone(), false));
92 }
93
94 #[cfg(target_os = "linux")]
95 let (upstream, downstream) = {
96 let up = LinuxTproxy::create_transparent_udp_socket(src)?;
98 up.connect(dst).await?;
99
100 let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
102 down.connect(src).await?;
103
104 (Some(Arc::new(up)), Some(Arc::new(down)))
105 };
106
107 let session = UdpSession {
109 flow_id: Uuid::new_v4(),
110 key: key.clone(),
111 created_at: Instant::now(),
112 last_activity: Arc::new(RwLock::new(Instant::now())),
113 packet_count: Arc::new(AtomicUsize::new(1)),
114 bytes_transferred: Arc::new(AtomicUsize::new(0)),
115 #[cfg(target_os = "linux")]
116 upstream_socket: upstream,
117 #[cfg(target_os = "linux")]
118 downstream_socket: downstream,
119 };
120
121 #[cfg(target_os = "linux")]
123 if let (Some(up), Some(down)) = (&session.upstream_socket, &session.downstream_socket) {
124 let up_clone = up.clone();
125 let down_clone = down.clone();
126 let last_activity = session.last_activity.clone();
127 let bytes_transferred = session.bytes_transferred.clone();
128
129 tokio::spawn(async move {
130 let mut buf = [0u8; 65535];
131 loop {
132 match up_clone.recv(&mut buf).await {
134 Ok(n) => {
135 if let Ok(mut last) = last_activity.try_write() {
137 *last = Instant::now();
138 }
139 bytes_transferred.fetch_add(n, Ordering::Relaxed);
140
141 if let Err(e) = down_clone.send(&buf[..n]).await {
143 tracing::debug!("UDP downstream send error: {}", e);
144 break;
145 }
146 }
147 Err(e) => {
148 tracing::debug!("UDP upstream recv error: {}", e);
149 break;
150 }
151 }
152 }
153 });
154 }
155
156 sessions.insert(key, session.clone());
157 Ok((session, true))
158 }
159
160 pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
162 let mut sessions = self.sessions.write().await;
163 let now = Instant::now();
164 let mut removed_ids = Vec::new();
165 let mut keys_to_remove = Vec::new();
166
167 for (key, session) in sessions.iter() {
169 let last = *session.last_activity.read().await;
170 if now.duration_since(last) > self.idle_timeout {
171 removed_ids.push(session.flow_id);
172 keys_to_remove.push(key.clone());
173 }
174 }
175
176 for key in keys_to_remove {
178 sessions.remove(&key);
179 }
180
181 removed_ids
182 }
183}
184
185pub struct UdpProxy {
187 socket: Arc<UdpSocket>,
188 #[allow(dead_code)]
189 session_manager: Arc<UdpSessionManager>,
190}
191
192impl UdpProxy {
193 pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
194 Self {
195 socket: Arc::new(socket),
196 session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
197 }
198 }
199
200 pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()>
202 {
203 let mut buf = [0u8; 65535];
204
205 #[cfg(target_os = "linux")]
206 {
207 LinuxTproxy::enable_tproxy(&self.socket)?;
209
210 loop {
211 let (len, src_addr, orig_dst) = match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
213 Ok(res) => res,
214 Err(e) => {
215 tracing::error!("UDP TPROXY recv error: {}", e);
216 continue;
217 }
218 };
219
220 if let Some(dst_addr) = orig_dst {
221 match self.session_manager.get_or_create_session(src_addr, dst_addr).await {
222 Ok((session, is_new)) => {
223 if is_new {
224 let flow = Flow {
226 id: session.flow_id,
227 start_time: Utc::now(),
228 end_time: None,
229 network: NetworkInfo {
230 client_ip: src_addr.ip().to_string(),
231 client_port: src_addr.port(),
232 server_ip: dst_addr.ip().to_string(),
233 server_port: dst_addr.port(),
234 protocol: TransportProtocol::UDP,
235 tls: false,
236 tls_version: None,
237 sni: None,
238 },
239 layer: Layer::Udp(UdpLayer {
240 payload_size: len,
241 packet_count: 1,
242 }),
243 tags: vec![],
244 meta: HashMap::new(),
245 };
246 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
247 crate::metrics::inc_flows_dropped();
248 }
249 }
250
251 if let Some(upstream) = &session.upstream_socket {
254 if let Err(e) = upstream.send(&buf[..len]).await {
255 tracing::debug!("UDP upstream send error: {}", e);
256 } else {
257 session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
258 }
259 }
260 }
261 Err(e) => {
262 tracing::warn!("Failed to create UDP session: {}", e);
263 }
264 }
265 }
266 }
267 }
268
269 #[cfg(not(target_os = "linux"))]
270 {
271 let _ = on_flow;
272 loop {
273 let (_len, _src_addr) = self.socket.recv_from(&mut buf).await?;
274 }
277 }
278 }
279}