1#[allow(unused_imports)]
2use chrono::Utc;
3#[allow(unused_imports)]
4use relay_core_api::flow::{Flow, FlowUpdate, Layer, NetworkInfo, TransportProtocol, UdpLayer};
5use std::collections::HashMap;
6use std::net::{IpAddr, SocketAddr};
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::net::UdpSocket;
10use tokio::sync::RwLock;
11use tokio::sync::mpsc::Sender;
12use uuid::Uuid;
13
14#[cfg(target_os = "linux")]
15use crate::capture::linux_tproxy::LinuxTproxy;
16
17use std::sync::atomic::{AtomicUsize, Ordering};
18
/// Identifies a UDP session by the address and port of both endpoints.
///
/// The key is directional: swapping source and destination yields a
/// different key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct UdpSessionKey {
    /// IP address of the sending side.
    pub src_ip: IpAddr,
    /// Port of the sending side.
    pub src_port: u16,
    /// IP address of the receiving side.
    pub dst_ip: IpAddr,
    /// Port of the receiving side.
    pub dst_port: u16,
}

impl UdpSessionKey {
    /// Builds the key for traffic flowing from `src` to `dst`.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        let (src_ip, src_port) = (src.ip(), src.port());
        let (dst_ip, dst_port) = (dst.ip(), dst.port());
        Self {
            src_ip,
            src_port,
            dst_ip,
            dst_port,
        }
    }
}
39
40#[derive(Debug, Clone)]
42pub struct UdpSession {
43 pub flow_id: Uuid,
44 pub key: UdpSessionKey,
45 pub created_at: Instant,
46 pub last_activity: Arc<RwLock<Instant>>,
47 pub packet_count: Arc<AtomicUsize>,
48 pub bytes_transferred: Arc<AtomicUsize>,
49 #[cfg(target_os = "linux")]
50 pub upstream_socket: Option<Arc<UdpSocket>>, #[cfg(target_os = "linux")]
52 pub downstream_socket: Option<Arc<UdpSocket>>, }
54
55pub struct UdpSessionManager {
57 sessions: RwLock<HashMap<UdpSessionKey, UdpSession>>,
58 idle_timeout: Duration,
59}
60
61impl UdpSessionManager {
62 pub fn new(idle_timeout: Duration) -> Self {
63 Self {
64 sessions: RwLock::new(HashMap::new()),
65 idle_timeout,
66 }
67 }
68
69 pub async fn get_or_create_session(
72 &self,
73 src: SocketAddr,
74 dst: SocketAddr,
75 ) -> std::io::Result<(UdpSession, bool)> {
76 let key = UdpSessionKey::new(src, dst);
77 {
79 let sessions = self.sessions.read().await;
80 if let Some(session) = sessions.get(&key) {
81 let mut last = session.last_activity.write().await;
82 *last = Instant::now();
83 session.packet_count.fetch_add(1, Ordering::Relaxed);
84 return Ok((session.clone(), false));
85 }
86 }
87
88 let mut sessions = self.sessions.write().await;
90 if let Some(session) = sessions.get(&key) {
92 let mut last = session.last_activity.write().await;
93 *last = Instant::now();
94 session.packet_count.fetch_add(1, Ordering::Relaxed);
95 return Ok((session.clone(), false));
96 }
97
98 #[cfg(target_os = "linux")]
99 let (upstream, downstream) = {
100 let up = LinuxTproxy::create_transparent_udp_socket(src)?;
102 up.connect(dst).await?;
103
104 let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
106 down.connect(src).await?;
107
108 (Some(Arc::new(up)), Some(Arc::new(down)))
109 };
110
111 let session = UdpSession {
113 flow_id: Uuid::new_v4(),
114 key: key.clone(),
115 created_at: Instant::now(),
116 last_activity: Arc::new(RwLock::new(Instant::now())),
117 packet_count: Arc::new(AtomicUsize::new(1)),
118 bytes_transferred: Arc::new(AtomicUsize::new(0)),
119 #[cfg(target_os = "linux")]
120 upstream_socket: upstream,
121 #[cfg(target_os = "linux")]
122 downstream_socket: downstream,
123 };
124
125 #[cfg(target_os = "linux")]
127 if let (Some(up), Some(down)) = (&session.upstream_socket, &session.downstream_socket) {
128 let up_clone = up.clone();
129 let down_clone = down.clone();
130 let last_activity = session.last_activity.clone();
131 let bytes_transferred = session.bytes_transferred.clone();
132
133 tokio::spawn(async move {
134 let mut buf = [0u8; 65535];
135 loop {
136 match up_clone.recv(&mut buf).await {
138 Ok(n) => {
139 if let Ok(mut last) = last_activity.try_write() {
141 *last = Instant::now();
142 }
143 bytes_transferred.fetch_add(n, Ordering::Relaxed);
144
145 if let Err(e) = down_clone.send(&buf[..n]).await {
147 tracing::debug!("UDP downstream send error: {}", e);
148 break;
149 }
150 }
151 Err(e) => {
152 tracing::debug!("UDP upstream recv error: {}", e);
153 break;
154 }
155 }
156 }
157 });
158 }
159
160 sessions.insert(key, session.clone());
161 Ok((session, true))
162 }
163
164 pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
166 let mut sessions = self.sessions.write().await;
167 let now = Instant::now();
168 let mut removed_ids = Vec::new();
169 let mut keys_to_remove = Vec::new();
170
171 for (key, session) in sessions.iter() {
173 let last = *session.last_activity.read().await;
174 if now.duration_since(last) > self.idle_timeout {
175 removed_ids.push(session.flow_id);
176 keys_to_remove.push(key.clone());
177 }
178 }
179
180 for key in keys_to_remove {
182 sessions.remove(&key);
183 }
184
185 removed_ids
186 }
187}
188
189pub struct UdpProxy {
191 socket: Arc<UdpSocket>,
192 #[allow(dead_code)]
193 session_manager: Arc<UdpSessionManager>,
194}
195
196impl UdpProxy {
197 pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
198 Self {
199 socket: Arc::new(socket),
200 session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
201 }
202 }
203
204 pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()> {
206 let mut buf = [0u8; 65535];
207
208 #[cfg(target_os = "linux")]
209 {
210 LinuxTproxy::enable_tproxy(&self.socket)?;
212
213 loop {
214 let (len, src_addr, orig_dst) =
216 match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
217 Ok(res) => res,
218 Err(e) => {
219 tracing::error!("UDP TPROXY recv error: {}", e);
220 continue;
221 }
222 };
223
224 if let Some(dst_addr) = orig_dst {
225 match self
226 .session_manager
227 .get_or_create_session(src_addr, dst_addr)
228 .await
229 {
230 Ok((session, is_new)) => {
231 if is_new {
232 let flow = Flow {
234 id: session.flow_id,
235 start_time: Utc::now(),
236 end_time: None,
237 network: NetworkInfo {
238 client_ip: src_addr.ip().to_string(),
239 client_port: src_addr.port(),
240 server_ip: dst_addr.ip().to_string(),
241 server_port: dst_addr.port(),
242 protocol: TransportProtocol::UDP,
243 tls: false,
244 tls_version: None,
245 sni: None,
246 },
247 layer: Layer::Udp(UdpLayer {
248 payload_size: len,
249 packet_count: 1,
250 }),
251 tags: vec![],
252 meta: HashMap::new(),
253 };
254 if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
255 crate::metrics::inc_flows_dropped();
256 }
257 }
258
259 if let Some(upstream) = &session.upstream_socket {
262 if let Err(e) = upstream.send(&buf[..len]).await {
263 tracing::debug!("UDP upstream send error: {}", e);
264 } else {
265 session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
266 }
267 }
268 }
269 Err(e) => {
270 tracing::warn!("Failed to create UDP session: {}", e);
271 }
272 }
273 }
274 }
275 }
276
277 #[cfg(not(target_os = "linux"))]
278 {
279 let _ = on_flow;
280 loop {
281 let (_len, _src_addr) = self.socket.recv_from(&mut buf).await?;
282 }
285 }
286 }
287}