Skip to main content

relay_core_lib/capture/
udp.rs

1use std::net::{SocketAddr, IpAddr};
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4use tokio::sync::RwLock;
5use tokio::sync::mpsc::Sender;
6use uuid::Uuid;
7use tokio::net::UdpSocket;
8use std::sync::Arc;
9#[allow(unused_imports)]
10use relay_core_api::flow::{Flow, FlowUpdate, NetworkInfo, TransportProtocol, Layer, UdpLayer};
11#[allow(unused_imports)]
12use chrono::Utc;
13
14#[cfg(target_os = "linux")]
15use crate::capture::linux_tproxy::LinuxTproxy;
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18
/// Identifies a UDP session by its source/destination address pair.
///
/// This is the classic 5-tuple with the protocol left out: every key in this
/// module describes UDP traffic, so storing it would be redundant.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UdpSessionKey {
    /// Source IP address.
    pub src_ip: IpAddr,
    /// Source port.
    pub src_port: u16,
    /// Destination IP address.
    pub dst_ip: IpAddr,
    /// Destination port.
    pub dst_port: u16,
}

impl UdpSessionKey {
    /// Builds the key for traffic travelling from `src` to `dst`.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        let (src_ip, src_port) = (src.ip(), src.port());
        let (dst_ip, dst_port) = (dst.ip(), dst.port());
        Self {
            src_ip,
            src_port,
            dst_ip,
            dst_port,
        }
    }
}
39
40/// UDP Session Metadata
41#[derive(Debug, Clone)]
42pub struct UdpSession {
43    pub flow_id: Uuid,
44    pub key: UdpSessionKey,
45    pub created_at: Instant,
46    pub last_activity: Arc<RwLock<Instant>>,
47    pub packet_count: Arc<AtomicUsize>,
48    pub bytes_transferred: Arc<AtomicUsize>,
49    #[cfg(target_os = "linux")]
50    pub upstream_socket: Option<Arc<UdpSocket>>, // Bound to src, connected to dst
51    #[cfg(target_os = "linux")]
52    pub downstream_socket: Option<Arc<UdpSocket>>, // Bound to dst, connected to src
53}
54
55/// Manager for tracking active UDP sessions
56pub struct UdpSessionManager {
57    sessions: RwLock<HashMap<UdpSessionKey, UdpSession>>,
58    idle_timeout: Duration,
59}
60
61impl UdpSessionManager {
62    pub fn new(idle_timeout: Duration) -> Self {
63        Self {
64            sessions: RwLock::new(HashMap::new()),
65            idle_timeout,
66        }
67    }
68
69    /// Get existing session or create new one
70    /// Returns (session, is_new)
71    pub async fn get_or_create_session(&self, src: SocketAddr, dst: SocketAddr) -> std::io::Result<(UdpSession, bool)> {
72        let key = UdpSessionKey::new(src, dst);
73        // Fast path: read lock
74        {
75            let sessions = self.sessions.read().await;
76            if let Some(session) = sessions.get(&key) {
77                let mut last = session.last_activity.write().await;
78                *last = Instant::now();
79                session.packet_count.fetch_add(1, Ordering::Relaxed);
80                return Ok((session.clone(), false));
81            }
82        }
83
84        // Slow path: write lock
85        let mut sessions = self.sessions.write().await;
86        // Check again
87        if let Some(session) = sessions.get(&key) {
88            let mut last = session.last_activity.write().await;
89            *last = Instant::now();
90            session.packet_count.fetch_add(1, Ordering::Relaxed);
91            return Ok((session.clone(), false));
92        }
93
94        #[cfg(target_os = "linux")]
95        let (upstream, downstream) = {
96            // Create upstream socket: Bound to src, connect to dst
97            let up = LinuxTproxy::create_transparent_udp_socket(src)?;
98            up.connect(dst).await?;
99            
100            // Create downstream socket: Bound to dst, connect to src
101            let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
102            down.connect(src).await?;
103            
104            (Some(Arc::new(up)), Some(Arc::new(down)))
105        };
106
107        // Create new session
108        let session = UdpSession {
109            flow_id: Uuid::new_v4(),
110            key: key.clone(),
111            created_at: Instant::now(),
112            last_activity: Arc::new(RwLock::new(Instant::now())),
113            packet_count: Arc::new(AtomicUsize::new(1)),
114            bytes_transferred: Arc::new(AtomicUsize::new(0)),
115            #[cfg(target_os = "linux")]
116            upstream_socket: upstream,
117            #[cfg(target_os = "linux")]
118            downstream_socket: downstream,
119        };
120
121        // Spawn reverse proxy task (B -> A)
122        #[cfg(target_os = "linux")]
123        if let (Some(up), Some(down)) = (&session.upstream_socket, &session.downstream_socket) {
124            let up_clone = up.clone();
125            let down_clone = down.clone();
126            let last_activity = session.last_activity.clone();
127            let bytes_transferred = session.bytes_transferred.clone();
128            
129            tokio::spawn(async move {
130                let mut buf = [0u8; 65535];
131                loop {
132                    // Read from upstream (response from Server B)
133                    match up_clone.recv(&mut buf).await {
134                        Ok(n) => {
135                            // Update activity
136                            if let Ok(mut last) = last_activity.try_write() {
137                                *last = Instant::now();
138                            }
139                            bytes_transferred.fetch_add(n, Ordering::Relaxed);
140                            
141                            // Send to downstream (to Client A)
142                            if let Err(e) = down_clone.send(&buf[..n]).await {
143                                tracing::debug!("UDP downstream send error: {}", e);
144                                break;
145                            }
146                        }
147                        Err(e) => {
148                            tracing::debug!("UDP upstream recv error: {}", e);
149                            break;
150                        }
151                    }
152                }
153            });
154        }
155
156        sessions.insert(key, session.clone());
157        Ok((session, true))
158    }
159
160    /// Clean up idle sessions
161    pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
162        let mut sessions = self.sessions.write().await;
163        let now = Instant::now();
164        let mut removed_ids = Vec::new();
165        let mut keys_to_remove = Vec::new();
166
167        // Identify idle sessions
168        for (key, session) in sessions.iter() {
169            let last = *session.last_activity.read().await;
170            if now.duration_since(last) > self.idle_timeout {
171                removed_ids.push(session.flow_id);
172                keys_to_remove.push(key.clone());
173            }
174        }
175
176        // Remove them
177        for key in keys_to_remove {
178            sessions.remove(&key);
179        }
180        
181        removed_ids
182    }
183}
184
185/// UDP Proxy capable of handling multiple sessions
186pub struct UdpProxy {
187    socket: Arc<UdpSocket>,
188    #[allow(dead_code)]
189    session_manager: Arc<UdpSessionManager>,
190}
191
192impl UdpProxy {
193    pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
194        Self {
195            socket: Arc::new(socket),
196            session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
197        }
198    }
199
200    /// Run the proxy loop
201    pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()> 
202    {
203        let mut buf = [0u8; 65535];
204        
205        #[cfg(target_os = "linux")]
206        {
207            // Enable TPROXY on socket
208            LinuxTproxy::enable_tproxy(&self.socket)?;
209            
210            loop {
211                // Use recv_original_dst
212                let (len, src_addr, orig_dst) = match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
213                    Ok(res) => res,
214                    Err(e) => {
215                        tracing::error!("UDP TPROXY recv error: {}", e);
216                        continue;
217                    }
218                };
219                
220                if let Some(dst_addr) = orig_dst {
221                     match self.session_manager.get_or_create_session(src_addr, dst_addr).await {
222                         Ok((session, is_new)) => {
223                             if is_new {
224                                 // Create initial flow
225                                 let flow = Flow {
226                                     id: session.flow_id,
227                                     start_time: Utc::now(),
228                                     end_time: None,
229                                     network: NetworkInfo {
230                                         client_ip: src_addr.ip().to_string(),
231                                         client_port: src_addr.port(),
232                                         server_ip: dst_addr.ip().to_string(),
233                                         server_port: dst_addr.port(),
234                                         protocol: TransportProtocol::UDP,
235                                         tls: false,
236                                         tls_version: None,
237                                         sni: None,
238                                     },
239                                     layer: Layer::Udp(UdpLayer {
240                                         payload_size: len,
241                                         packet_count: 1,
242                                     }),
243                                     tags: vec![],
244                                     meta: HashMap::new(),
245                                 };
246                                 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
247                                     crate::metrics::inc_flows_dropped();
248                                 }
249                             }
250                             
251                             // Forward packet logic (A -> B)
252                             // Using upstream socket bound to src_addr
253                             if let Some(upstream) = &session.upstream_socket {
254                                 if let Err(e) = upstream.send(&buf[..len]).await {
255                                     tracing::debug!("UDP upstream send error: {}", e);
256                                 } else {
257                                     session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
258                                 }
259                             }
260                         }
261                         Err(e) => {
262                             tracing::warn!("Failed to create UDP session: {}", e);
263                         }
264                     }
265                }
266            }
267        }
268        
269        #[cfg(not(target_os = "linux"))]
270        {
271             let _ = on_flow;
272             loop {
273                let (_len, _src_addr) = self.socket.recv_from(&mut buf).await?;
274                // Without TPROXY, we don't know the original destination easily
275                // Just consume packets to avoid buffer bloat
276             }
277        }
278    }
279}