Skip to main content

relay_core_lib/capture/
udp.rs

1#[allow(unused_imports)]
2use chrono::Utc;
3#[allow(unused_imports)]
4use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TransportProtocol, UdpLayer};
5use std::collections::HashMap;
6use std::net::{IpAddr, SocketAddr};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::net::UdpSocket;
10use tokio::sync::RwLock;
11use tokio::sync::mpsc::Sender;
12use uuid::Uuid;
13
14#[cfg(target_os = "linux")]
15use crate::capture::linux_tproxy::LinuxTproxy;
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18
/// Identity of a UDP session: the sender and receiver endpoints.
///
/// The transport protocol is not part of the key because every session
/// tracked with it is UDP.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UdpSessionKey {
    /// IP address of the sending endpoint.
    pub src_ip: IpAddr,
    /// Port of the sending endpoint.
    pub src_port: u16,
    /// IP address of the receiving endpoint.
    pub dst_ip: IpAddr,
    /// Port of the receiving endpoint.
    pub dst_port: u16,
}

impl UdpSessionKey {
    /// Builds a key by splitting `src` and `dst` into their IP and port parts.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        let (src_ip, src_port) = (src.ip(), src.port());
        let (dst_ip, dst_port) = (dst.ip(), dst.port());
        Self {
            src_ip,
            src_port,
            dst_ip,
            dst_port,
        }
    }
}
39
40/// UDP Session Metadata
41#[derive(Debug, Clone)]
42pub struct UdpSession {
43    pub flow_id: Uuid,
44    pub key: UdpSessionKey,
45    pub created_at: Instant,
46    pub last_activity: Arc<RwLock<Instant>>,
47    pub packet_count: Arc<AtomicUsize>,
48    pub bytes_transferred: Arc<AtomicUsize>,
49    #[cfg(target_os = "linux")]
50    pub upstream_socket: Option<Arc<UdpSocket>>, // Bound to src, connected to dst
51    #[cfg(target_os = "linux")]
52    pub downstream_socket: Option<Arc<UdpSocket>>, // Bound to dst, connected to src
53}
54
55/// Manager for tracking active UDP sessions
56pub struct UdpSessionManager {
57    sessions: RwLock<HashMap<UdpSessionKey, UdpSession>>,
58    idle_timeout: Duration,
59}
60
61impl UdpSessionManager {
62    pub fn new(idle_timeout: Duration) -> Self {
63        Self {
64            sessions: RwLock::new(HashMap::new()),
65            idle_timeout,
66        }
67    }
68
69    /// Get existing session or create new one
70    /// Returns (session, is_new)
71    pub async fn get_or_create_session(
72        &self,
73        src: SocketAddr,
74        dst: SocketAddr,
75    ) -> std::io::Result<(UdpSession, bool)> {
76        let key = UdpSessionKey::new(src, dst);
77        // Fast path: read lock
78        {
79            let sessions = self.sessions.read().await;
80            if let Some(session) = sessions.get(&key) {
81                let mut last = session.last_activity.write().await;
82                *last = Instant::now();
83                session.packet_count.fetch_add(1, Ordering::Relaxed);
84                return Ok((session.clone(), false));
85            }
86        }
87
88        // Slow path: write lock
89        let mut sessions = self.sessions.write().await;
90        // Check again
91        if let Some(session) = sessions.get(&key) {
92            let mut last = session.last_activity.write().await;
93            *last = Instant::now();
94            session.packet_count.fetch_add(1, Ordering::Relaxed);
95            return Ok((session.clone(), false));
96        }
97
98        #[cfg(target_os = "linux")]
99        let (upstream, downstream) = {
100            // Create upstream socket: Bound to src, connect to dst
101            let up = LinuxTproxy::create_transparent_udp_socket(src)?;
102            up.connect(dst).await?;
103
104            // Create downstream socket: Bound to dst, connect to src
105            let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
106            down.connect(src).await?;
107
108            (Some(Arc::new(up)), Some(Arc::new(down)))
109        };
110
111        // Create new session
112        let session = UdpSession {
113            flow_id: Uuid::new_v4(),
114            key: key.clone(),
115            created_at: Instant::now(),
116            last_activity: Arc::new(RwLock::new(Instant::now())),
117            packet_count: Arc::new(AtomicUsize::new(1)),
118            bytes_transferred: Arc::new(AtomicUsize::new(0)),
119            #[cfg(target_os = "linux")]
120            upstream_socket: upstream,
121            #[cfg(target_os = "linux")]
122            downstream_socket: downstream,
123        };
124
125        // Spawn reverse proxy task (B -> A)
126        #[cfg(target_os = "linux")]
127        if let (Some(up), Some(down)) = (&session.upstream_socket, &session.downstream_socket) {
128            let up_clone = up.clone();
129            let down_clone = down.clone();
130            let last_activity = session.last_activity.clone();
131            let bytes_transferred = session.bytes_transferred.clone();
132
133            tokio::spawn(async move {
134                let mut buf = [0u8; 65535];
135                loop {
136                    // Read from upstream (response from Server B)
137                    match up_clone.recv(&mut buf).await {
138                        Ok(n) => {
139                            // Update activity
140                            if let Ok(mut last) = last_activity.try_write() {
141                                *last = Instant::now();
142                            }
143                            bytes_transferred.fetch_add(n, Ordering::Relaxed);
144
145                            // Send to downstream (to Client A)
146                            if let Err(e) = down_clone.send(&buf[..n]).await {
147                                tracing::debug!("UDP downstream send error: {}", e);
148                                break;
149                            }
150                        }
151                        Err(e) => {
152                            tracing::debug!("UDP upstream recv error: {}", e);
153                            break;
154                        }
155                    }
156                }
157            });
158        }
159
160        sessions.insert(key, session.clone());
161        Ok((session, true))
162    }
163
164    /// Clean up idle sessions
165    pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
166        let mut sessions = self.sessions.write().await;
167        let now = Instant::now();
168        let mut removed_ids = Vec::new();
169        let mut keys_to_remove = Vec::new();
170
171        // Identify idle sessions
172        for (key, session) in sessions.iter() {
173            let last = *session.last_activity.read().await;
174            if now.duration_since(last) > self.idle_timeout {
175                removed_ids.push(session.flow_id);
176                keys_to_remove.push(key.clone());
177            }
178        }
179
180        // Remove them
181        for key in keys_to_remove {
182            sessions.remove(&key);
183        }
184
185        removed_ids
186    }
187}
188
189/// UDP Proxy capable of handling multiple sessions
190pub struct UdpProxy {
191    socket: Arc<UdpSocket>,
192    session_manager: Arc<UdpSessionManager>,
193    remote_addr: Option<SocketAddr>,
194}
195
196impl UdpProxy {
197    pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
198        Self {
199            socket: Arc::new(socket),
200            session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
201            remote_addr: None,
202        }
203    }
204
205    pub fn with_remote(mut self, addr: SocketAddr) -> Self {
206        self.remote_addr = Some(addr);
207        self
208    }
209
210    /// Run the proxy loop
211    pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()> {
212        let mut buf = [0u8; 65535];
213
214        #[cfg(target_os = "linux")]
215        {
216            // Enable TPROXY on socket
217            LinuxTproxy::enable_tproxy(&self.socket)?;
218
219            loop {
220                // Use recv_original_dst
221                let (len, src_addr, orig_dst) =
222                    match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
223                        Ok(res) => res,
224                        Err(e) => {
225                            tracing::error!("UDP TPROXY recv error: {}", e);
226                            continue;
227                        }
228                    };
229
230                if let Some(dst_addr) = orig_dst {
231                    match self
232                        .session_manager
233                        .get_or_create_session(src_addr, dst_addr)
234                        .await
235                    {
236                        Ok((session, is_new)) => {
237                            if is_new {
238                                // Create initial flow
239                                let flow = Flow {
240                                    id: session.flow_id,
241                                    start_time: Utc::now(),
242                                    end_time: None,
243                                    network: NetworkInfo {
244                                        client_ip: src_addr.ip().to_string(),
245                                        client_port: src_addr.port(),
246                                        server_ip: dst_addr.ip().to_string(),
247                                        server_port: dst_addr.port(),
248                                        protocol: TransportProtocol::UDP,
249                                        tls: false,
250                                        tls_version: None,
251                                        sni: None,
252                                    },
253                                    layer: Layer::Udp(UdpLayer {
254                                        payload_size: len,
255                                        packet_count: 1,
256                                    }),
257                                    tags: vec![],
258                                    meta: HashMap::new(),
259                                    resilience_trace: None,
260                                    rule_variables: HashMap::new(),
261                                    matched_rules: vec![],
262                                };
263                                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
264                                    crate::metrics::inc_flows_dropped();
265                                }
266                            }
267
268                            // Forward packet logic (A -> B)
269                            // Using upstream socket bound to src_addr
270                            if let Some(upstream) = &session.upstream_socket {
271                                if let Err(e) = upstream.send(&buf[..len]).await {
272                                    tracing::debug!("UDP upstream send error: {}", e);
273                                } else {
274                                    session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
275                                }
276                            }
277                        }
278                        Err(e) => {
279                            tracing::warn!("Failed to create UDP session: {}", e);
280                        }
281                    }
282                }
283            }
284        }
285
286        #[cfg(not(target_os = "linux"))]
287        {
288            let remote_addr = match self.remote_addr {
289                Some(addr) => addr,
290                None => {
291                    tracing::warn!(
292                        "UDP proxy started without remote_addr on non-Linux; no forwarding"
293                    );
294                    loop {
295                        match self.socket.recv_from(&mut buf).await {
296                            Ok((_len, _src_addr)) => {}
297                            Err(e) => {
298                                tracing::error!("UDP drain recv error: {}", e);
299                                continue;
300                            }
301                        }
302                    }
303                }
304            };
305
306            let sm = self.session_manager.clone();
307            let sock = self.socket.clone();
308            let flow_tx = on_flow;
309
310            loop {
311                let (len, src_addr) = match sock.recv_from(&mut buf).await {
312                    Ok(res) => res,
313                    Err(e) => {
314                        tracing::error!("UDP recv error: {}", e);
315                        continue;
316                    }
317                };
318
319                let (session, is_new) = match sm.get_or_create_session(src_addr, remote_addr).await
320                {
321                    Ok(res) => res,
322                    Err(e) => {
323                        tracing::warn!("Failed to create UDP session: {}", e);
324                        continue;
325                    }
326                };
327
328                if is_new {
329                    let flow = Flow {
330                        id: session.flow_id,
331                        start_time: Utc::now(),
332                        end_time: None,
333                        network: NetworkInfo {
334                            client_ip: src_addr.ip().to_string(),
335                            client_port: src_addr.port(),
336                            server_ip: remote_addr.ip().to_string(),
337                            server_port: remote_addr.port(),
338                            protocol: TransportProtocol::UDP,
339                            tls: false,
340                            tls_version: None,
341                            sni: None,
342                        },
343                        layer: Layer::Udp(UdpLayer {
344                            payload_size: len,
345                            packet_count: 1,
346                        }),
347                        tags: vec![],
348                        meta: HashMap::new(),
349                        resilience_trace: None,
350                        rule_variables: HashMap::new(),
351                        matched_rules: vec![],
352                    };
353                    let _ = flow_tx.try_send(FlowUpdate::Full(Box::new(flow)));
354
355                    let sock_clone = sock.clone();
356                    let bytes = session.bytes_transferred.clone();
357                    let last = session.last_activity.clone();
358                    tokio::spawn(async move {
359                        let mut rbuf = [0u8; 65535];
360                        loop {
361                            match sock_clone.recv_from(&mut rbuf).await {
362                                Ok((n, addr)) => {
363                                    if addr == remote_addr {
364                                        let _ = sock_clone.send_to(&rbuf[..n], src_addr).await;
365                                        bytes.fetch_add(n, Ordering::Relaxed);
366                                        if let Ok(mut la) = last.try_write() {
367                                            *la = Instant::now();
368                                        }
369                                    }
370                                }
371                                Err(e) => {
372                                    tracing::debug!(
373                                        "UDP reverse recv error for {}: {}",
374                                        session.flow_id,
375                                        e
376                                    );
377                                    break;
378                                }
379                            }
380                        }
381                    });
382                }
383
384                match sock.send_to(&buf[..len], remote_addr).await {
385                    Ok(_) => {
386                        session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
387                    }
388                    Err(e) => {
389                        tracing::debug!("UDP send_to {} error: {}", remote_addr, e);
390                    }
391                }
392            }
393        }
394    }
395}