1#[allow(unused_imports)]
2use chrono::Utc;
3#[allow(unused_imports)]
4use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TransportProtocol, UdpLayer};
5use std::collections::HashMap;
6use std::net::{IpAddr, SocketAddr};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::net::UdpSocket;
10use tokio::sync::RwLock;
11use tokio::sync::mpsc::Sender;
12use uuid::Uuid;
13
14#[cfg(target_os = "linux")]
15use crate::capture::linux_tproxy::LinuxTproxy;
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18
/// Identifies a UDP session by its source (client) and destination (server) endpoints.
///
/// Used as the key of the session table, which is why it derives `Eq` and `Hash`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UdpSessionKey {
    /// Source IP address of the datagrams.
    pub src_ip: IpAddr,
    /// Source UDP port of the datagrams.
    pub src_port: u16,
    /// Destination IP address of the datagrams.
    pub dst_ip: IpAddr,
    /// Destination UDP port of the datagrams.
    pub dst_port: u16,
}
28
29impl UdpSessionKey {
30 pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
31 Self {
32 src_ip: src.ip(),
33 src_port: src.port(),
34 dst_ip: dst.ip(),
35 dst_port: dst.port(),
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct UdpSession {
43 pub flow_id: Uuid,
44 pub key: UdpSessionKey,
45 pub created_at: Instant,
46 pub last_activity: Arc<RwLock<Instant>>,
47 pub packet_count: Arc<AtomicUsize>,
48 pub bytes_transferred: Arc<AtomicUsize>,
49 #[cfg(target_os = "linux")]
50 pub upstream_socket: Option<Arc<UdpSocket>>, #[cfg(target_os = "linux")]
52 pub downstream_socket: Option<Arc<UdpSocket>>, }
54
55pub struct UdpSessionManager {
57 sessions: RwLock<HashMap<UdpSessionKey, UdpSession>>,
58 idle_timeout: Duration,
59}
60
61impl UdpSessionManager {
62 pub fn new(idle_timeout: Duration) -> Self {
63 Self {
64 sessions: RwLock::new(HashMap::new()),
65 idle_timeout,
66 }
67 }
68
69 pub async fn get_or_create_session(
72 &self,
73 src: SocketAddr,
74 dst: SocketAddr,
75 ) -> std::io::Result<(UdpSession, bool)> {
76 let key = UdpSessionKey::new(src, dst);
77 {
79 let sessions = self.sessions.read().await;
80 if let Some(session) = sessions.get(&key) {
81 let mut last = session.last_activity.write().await;
82 *last = Instant::now();
83 session.packet_count.fetch_add(1, Ordering::Relaxed);
84 return Ok((session.clone(), false));
85 }
86 }
87
88 let mut sessions = self.sessions.write().await;
90 if let Some(session) = sessions.get(&key) {
92 let mut last = session.last_activity.write().await;
93 *last = Instant::now();
94 session.packet_count.fetch_add(1, Ordering::Relaxed);
95 return Ok((session.clone(), false));
96 }
97
98 #[cfg(target_os = "linux")]
99 let (upstream, downstream) = {
100 let up = LinuxTproxy::create_transparent_udp_socket(src)?;
102 up.connect(dst).await?;
103
104 let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
106 down.connect(src).await?;
107
108 (Some(Arc::new(up)), Some(Arc::new(down)))
109 };
110
111 let session = UdpSession {
113 flow_id: Uuid::new_v4(),
114 key: key.clone(),
115 created_at: Instant::now(),
116 last_activity: Arc::new(RwLock::new(Instant::now())),
117 packet_count: Arc::new(AtomicUsize::new(1)),
118 bytes_transferred: Arc::new(AtomicUsize::new(0)),
119 #[cfg(target_os = "linux")]
120 upstream_socket: upstream,
121 #[cfg(target_os = "linux")]
122 downstream_socket: downstream,
123 };
124
125 #[cfg(target_os = "linux")]
127 if let (Some(up), Some(down)) = (&session.upstream_socket, &session.downstream_socket) {
128 let up_clone = up.clone();
129 let down_clone = down.clone();
130 let last_activity = session.last_activity.clone();
131 let bytes_transferred = session.bytes_transferred.clone();
132
133 tokio::spawn(async move {
134 let mut buf = [0u8; 65535];
135 loop {
136 match up_clone.recv(&mut buf).await {
138 Ok(n) => {
139 if let Ok(mut last) = last_activity.try_write() {
141 *last = Instant::now();
142 }
143 bytes_transferred.fetch_add(n, Ordering::Relaxed);
144
145 if let Err(e) = down_clone.send(&buf[..n]).await {
147 tracing::debug!("UDP downstream send error: {}", e);
148 break;
149 }
150 }
151 Err(e) => {
152 tracing::debug!("UDP upstream recv error: {}", e);
153 break;
154 }
155 }
156 }
157 });
158 }
159
160 sessions.insert(key, session.clone());
161 Ok((session, true))
162 }
163
164 pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
166 let mut sessions = self.sessions.write().await;
167 let now = Instant::now();
168 let mut removed_ids = Vec::new();
169 let mut keys_to_remove = Vec::new();
170
171 for (key, session) in sessions.iter() {
173 let last = *session.last_activity.read().await;
174 if now.duration_since(last) > self.idle_timeout {
175 removed_ids.push(session.flow_id);
176 keys_to_remove.push(key.clone());
177 }
178 }
179
180 for key in keys_to_remove {
182 sessions.remove(&key);
183 }
184
185 removed_ids
186 }
187}
188
189pub struct UdpProxy {
191 socket: Arc<UdpSocket>,
192 session_manager: Arc<UdpSessionManager>,
193 remote_addr: Option<SocketAddr>,
194}
195
196impl UdpProxy {
197 pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
198 Self {
199 socket: Arc::new(socket),
200 session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
201 remote_addr: None,
202 }
203 }
204
205 pub fn with_remote(mut self, addr: SocketAddr) -> Self {
206 self.remote_addr = Some(addr);
207 self
208 }
209
210 pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()> {
212 let mut buf = [0u8; 65535];
213
214 #[cfg(target_os = "linux")]
215 {
216 LinuxTproxy::enable_tproxy(&self.socket)?;
218
219 loop {
220 let (len, src_addr, orig_dst) =
222 match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
223 Ok(res) => res,
224 Err(e) => {
225 tracing::error!("UDP TPROXY recv error: {}", e);
226 continue;
227 }
228 };
229
230 if let Some(dst_addr) = orig_dst {
231 match self
232 .session_manager
233 .get_or_create_session(src_addr, dst_addr)
234 .await
235 {
236 Ok((session, is_new)) => {
237 if is_new {
238 let flow = Flow {
240 id: session.flow_id,
241 start_time: Utc::now(),
242 end_time: None,
243 network: NetworkInfo {
244 client_ip: src_addr.ip().to_string(),
245 client_port: src_addr.port(),
246 server_ip: dst_addr.ip().to_string(),
247 server_port: dst_addr.port(),
248 protocol: TransportProtocol::UDP,
249 tls: false,
250 tls_version: None,
251 sni: None,
252 },
253 layer: Layer::Udp(UdpLayer {
254 payload_size: len,
255 packet_count: 1,
256 }),
257 tags: vec![],
258 meta: HashMap::new(),
259 resilience_trace: None,
260 rule_variables: HashMap::new(),
261 matched_rules: vec![],
262 };
263 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
264 crate::metrics::inc_flows_dropped();
265 }
266 }
267
268 if let Some(upstream) = &session.upstream_socket {
271 if let Err(e) = upstream.send(&buf[..len]).await {
272 tracing::debug!("UDP upstream send error: {}", e);
273 } else {
274 session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
275 }
276 }
277 }
278 Err(e) => {
279 tracing::warn!("Failed to create UDP session: {}", e);
280 }
281 }
282 }
283 }
284 }
285
286 #[cfg(not(target_os = "linux"))]
287 {
288 let remote_addr = match self.remote_addr {
289 Some(addr) => addr,
290 None => {
291 tracing::warn!(
292 "UDP proxy started without remote_addr on non-Linux; no forwarding"
293 );
294 loop {
295 match self.socket.recv_from(&mut buf).await {
296 Ok((_len, _src_addr)) => {}
297 Err(e) => {
298 tracing::error!("UDP drain recv error: {}", e);
299 continue;
300 }
301 }
302 }
303 }
304 };
305
306 let sm = self.session_manager.clone();
307 let sock = self.socket.clone();
308 let flow_tx = on_flow;
309
310 loop {
311 let (len, src_addr) = match sock.recv_from(&mut buf).await {
312 Ok(res) => res,
313 Err(e) => {
314 tracing::error!("UDP recv error: {}", e);
315 continue;
316 }
317 };
318
319 let (session, is_new) = match sm.get_or_create_session(src_addr, remote_addr).await
320 {
321 Ok(res) => res,
322 Err(e) => {
323 tracing::warn!("Failed to create UDP session: {}", e);
324 continue;
325 }
326 };
327
328 if is_new {
329 let flow = Flow {
330 id: session.flow_id,
331 start_time: Utc::now(),
332 end_time: None,
333 network: NetworkInfo {
334 client_ip: src_addr.ip().to_string(),
335 client_port: src_addr.port(),
336 server_ip: remote_addr.ip().to_string(),
337 server_port: remote_addr.port(),
338 protocol: TransportProtocol::UDP,
339 tls: false,
340 tls_version: None,
341 sni: None,
342 },
343 layer: Layer::Udp(UdpLayer {
344 payload_size: len,
345 packet_count: 1,
346 }),
347 tags: vec![],
348 meta: HashMap::new(),
349 resilience_trace: None,
350 rule_variables: HashMap::new(),
351 matched_rules: vec![],
352 };
353 let _ = flow_tx.try_send(FlowUpdate::Full(Box::new(flow)));
354
355 let sock_clone = sock.clone();
356 let bytes = session.bytes_transferred.clone();
357 let last = session.last_activity.clone();
358 tokio::spawn(async move {
359 let mut rbuf = [0u8; 65535];
360 loop {
361 match sock_clone.recv_from(&mut rbuf).await {
362 Ok((n, addr)) => {
363 if addr == remote_addr {
364 let _ = sock_clone.send_to(&rbuf[..n], src_addr).await;
365 bytes.fetch_add(n, Ordering::Relaxed);
366 if let Ok(mut la) = last.try_write() {
367 *la = Instant::now();
368 }
369 }
370 }
371 Err(e) => {
372 tracing::debug!(
373 "UDP reverse recv error for {}: {}",
374 session.flow_id,
375 e
376 );
377 break;
378 }
379 }
380 }
381 });
382 }
383
384 match sock.send_to(&buf[..len], remote_addr).await {
385 Ok(_) => {
386 session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
387 }
388 Err(e) => {
389 tracing::debug!("UDP send_to {} error: {}", remote_addr, e);
390 }
391 }
392 }
393 }
394 }
395}