Skip to main content

relay_core_lib/capture/
udp.rs

1#[allow(unused_imports)]
2use chrono::Utc;
3#[allow(unused_imports)]
4use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TransportProtocol, UdpLayer};
5use std::collections::HashMap;
6use std::net::{IpAddr, SocketAddr};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::net::UdpSocket;
10use tokio::sync::RwLock;
11use tokio::sync::mpsc::Sender;
12use uuid::Uuid;
13
14#[cfg(target_os = "linux")]
15use crate::capture::linux_tproxy::LinuxTproxy;
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18
/// Identifies a UDP session by its two endpoints.
///
/// The transport protocol is not stored: every key in this module refers to UDP,
/// so source and destination address/port are enough to tell sessions apart.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UdpSessionKey {
    /// IP address the datagrams originate from.
    pub src_ip: IpAddr,
    /// Port the datagrams originate from.
    pub src_port: u16,
    /// IP address the datagrams are addressed to.
    pub dst_ip: IpAddr,
    /// Port the datagrams are addressed to.
    pub dst_port: u16,
}

impl UdpSessionKey {
    /// Builds a key from the source and destination socket addresses.
    ///
    /// The key is directional: swapping `src` and `dst` yields a different key.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        let (src_ip, src_port) = (src.ip(), src.port());
        let (dst_ip, dst_port) = (dst.ip(), dst.port());

        Self {
            src_ip,
            src_port,
            dst_ip,
            dst_port,
        }
    }
}
39
/// UDP Session Metadata
///
/// `Clone` is derived and the mutable state lives behind `Arc`s, so every clone of a
/// session observes and updates the same activity timestamp, counters and sockets.
#[derive(Debug, Clone)]
pub struct UdpSession {
    /// Identifier used when this session is reported as a flow.
    pub flow_id: Uuid,
    /// Source/destination endpoints this session was created for.
    pub key: UdpSessionKey,
    /// Monotonic timestamp taken when the session was created.
    pub created_at: Instant,
    /// Monotonic timestamp of the most recent traffic; compared against the idle
    /// timeout when sessions are cleaned up.
    pub last_activity: Arc<RwLock<Instant>>,
    /// Packet counter shared between all clones of this session.
    pub packet_count: Arc<AtomicUsize>,
    /// Byte counter shared between all clones of this session.
    pub bytes_transferred: Arc<AtomicUsize>,
    /// Socket used to send client datagrams on to the server (Linux only).
    #[cfg(target_os = "linux")]
    pub upstream_socket: Option<Arc<UdpSocket>>, // Bound to src, connected to dst
    /// Socket used to send server replies back to the client (Linux only).
    #[cfg(target_os = "linux")]
    pub downstream_socket: Option<Arc<UdpSocket>>, // Bound to dst, connected to src
}
54
55/// Manager for tracking active UDP sessions
56pub struct UdpSessionManager {
57    sessions: RwLock<HashMap<UdpSessionKey, UdpSession>>,
58    idle_timeout: Duration,
59}
60
61impl UdpSessionManager {
62    pub fn new(idle_timeout: Duration) -> Self {
63        Self {
64            sessions: RwLock::new(HashMap::new()),
65            idle_timeout,
66        }
67    }
68
69    /// Get existing session or create new one
70    /// Returns (session, is_new)
71    pub async fn get_or_create_session(
72        &self,
73        src: SocketAddr,
74        dst: SocketAddr,
75    ) -> std::io::Result<(UdpSession, bool)> {
76        let key = UdpSessionKey::new(src, dst);
77        // Fast path: read lock
78        {
79            let sessions = self.sessions.read().await;
80            if let Some(session) = sessions.get(&key) {
81                let mut last = session.last_activity.write().await;
82                *last = Instant::now();
83                session.packet_count.fetch_add(1, Ordering::Relaxed);
84                return Ok((session.clone(), false));
85            }
86        }
87
88        // Slow path: write lock
89        let mut sessions = self.sessions.write().await;
90        // Check again
91        if let Some(session) = sessions.get(&key) {
92            let mut last = session.last_activity.write().await;
93            *last = Instant::now();
94            session.packet_count.fetch_add(1, Ordering::Relaxed);
95            return Ok((session.clone(), false));
96        }
97
98        #[cfg(target_os = "linux")]
99        let (upstream, downstream) = {
100            // Create upstream socket: Bound to src, connect to dst
101            let up = LinuxTproxy::create_transparent_udp_socket(src)?;
102            up.connect(dst).await?;
103
104            // Create downstream socket: Bound to dst, connect to src
105            let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
106            down.connect(src).await?;
107
108            (Some(Arc::new(up)), Some(Arc::new(down)))
109        };
110
111        // Create new session
112        let session = UdpSession {
113            flow_id: Uuid::new_v4(),
114            key: key.clone(),
115            created_at: Instant::now(),
116            last_activity: Arc::new(RwLock::new(Instant::now())),
117            packet_count: Arc::new(AtomicUsize::new(1)),
118            bytes_transferred: Arc::new(AtomicUsize::new(0)),
119            #[cfg(target_os = "linux")]
120            upstream_socket: upstream,
121            #[cfg(target_os = "linux")]
122            downstream_socket: downstream,
123        };
124
125        // Spawn reverse proxy task (B -> A)
126        #[cfg(target_os = "linux")]
127        if let (Some(up), Some(down)) = (&session.upstream_socket, &session.downstream_socket) {
128            let up_clone = up.clone();
129            let down_clone = down.clone();
130            let last_activity = session.last_activity.clone();
131            let bytes_transferred = session.bytes_transferred.clone();
132
133            tokio::spawn(async move {
134                let mut buf = [0u8; 65535];
135                loop {
136                    // Read from upstream (response from Server B)
137                    match up_clone.recv(&mut buf).await {
138                        Ok(n) => {
139                            // Update activity
140                            if let Ok(mut last) = last_activity.try_write() {
141                                *last = Instant::now();
142                            }
143                            bytes_transferred.fetch_add(n, Ordering::Relaxed);
144
145                            // Send to downstream (to Client A)
146                            if let Err(e) = down_clone.send(&buf[..n]).await {
147                                tracing::debug!("UDP downstream send error: {}", e);
148                                break;
149                            }
150                        }
151                        Err(e) => {
152                            tracing::debug!("UDP upstream recv error: {}", e);
153                            break;
154                        }
155                    }
156                }
157            });
158        }
159
160        sessions.insert(key, session.clone());
161        Ok((session, true))
162    }
163
164    /// Clean up idle sessions
165    pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
166        let mut sessions = self.sessions.write().await;
167        let now = Instant::now();
168        let mut removed_ids = Vec::new();
169        let mut keys_to_remove = Vec::new();
170
171        // Identify idle sessions
172        for (key, session) in sessions.iter() {
173            let last = *session.last_activity.read().await;
174            if now.duration_since(last) > self.idle_timeout {
175                removed_ids.push(session.flow_id);
176                keys_to_remove.push(key.clone());
177            }
178        }
179
180        // Remove them
181        for key in keys_to_remove {
182            sessions.remove(&key);
183        }
184
185        removed_ids
186    }
187}
188
189/// UDP Proxy capable of handling multiple sessions
190pub struct UdpProxy {
191    socket: Arc<UdpSocket>,
192    #[allow(dead_code)]
193    session_manager: Arc<UdpSessionManager>,
194}
195
196impl UdpProxy {
197    pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
198        Self {
199            socket: Arc::new(socket),
200            session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
201        }
202    }
203
204    /// Run the proxy loop
205    pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()> {
206        let mut buf = [0u8; 65535];
207
208        #[cfg(target_os = "linux")]
209        {
210            // Enable TPROXY on socket
211            LinuxTproxy::enable_tproxy(&self.socket)?;
212
213            loop {
214                // Use recv_original_dst
215                let (len, src_addr, orig_dst) =
216                    match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
217                        Ok(res) => res,
218                        Err(e) => {
219                            tracing::error!("UDP TPROXY recv error: {}", e);
220                            continue;
221                        }
222                    };
223
224                if let Some(dst_addr) = orig_dst {
225                    match self
226                        .session_manager
227                        .get_or_create_session(src_addr, dst_addr)
228                        .await
229                    {
230                        Ok((session, is_new)) => {
231                            if is_new {
232                                // Create initial flow
233                                let flow = Flow {
234                                    id: session.flow_id,
235                                    start_time: Utc::now(),
236                                    end_time: None,
237                                    network: NetworkInfo {
238                                        client_ip: src_addr.ip().to_string(),
239                                        client_port: src_addr.port(),
240                                        server_ip: dst_addr.ip().to_string(),
241                                        server_port: dst_addr.port(),
242                                        protocol: TransportProtocol::UDP,
243                                        tls: false,
244                                        tls_version: None,
245                                        sni: None,
246                                    },
247                                    layer: Layer::Udp(UdpLayer {
248                                        payload_size: len,
249                                        packet_count: 1,
250                                    }),
251                                    tags: vec![],
252                                    meta: HashMap::new(),
253                                };
254                                if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
255                                    crate::metrics::inc_flows_dropped();
256                                }
257                            }
258
259                            // Forward packet logic (A -> B)
260                            // Using upstream socket bound to src_addr
261                            if let Some(upstream) = &session.upstream_socket {
262                                if let Err(e) = upstream.send(&buf[..len]).await {
263                                    tracing::debug!("UDP upstream send error: {}", e);
264                                } else {
265                                    session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
266                                }
267                            }
268                        }
269                        Err(e) => {
270                            tracing::warn!("Failed to create UDP session: {}", e);
271                        }
272                    }
273                }
274            }
275        }
276
277        #[cfg(not(target_os = "linux"))]
278        {
279            let _ = on_flow;
280            loop {
281                let (_len, _src_addr) = self.socket.recv_from(&mut buf).await?;
282                // Without TPROXY, we don't know the original destination easily
283                // Just consume packets to avoid buffer bloat
284            }
285        }
286    }
287}