Skip to main content

zlayer_proxy/stream/
udp.rs

1//! UDP stream proxy service
2//!
3//! Custom UDP proxy implementation with session tracking.
4//! Each client gets a dedicated session that maps to a backend.
5
6use dashmap::DashMap;
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::net::UdpSocket;
11use tokio::sync::mpsc;
12
13use super::config::DEFAULT_UDP_SESSION_TIMEOUT;
14use super::registry::StreamRegistry;
15
16/// UDP session state
17struct UdpSession {
18    /// Backend address for this session
19    backend: SocketAddr,
20    /// Socket to communicate with backend (bound to ephemeral port)
21    backend_socket: Arc<UdpSocket>,
22    /// Last activity timestamp
23    last_activity: Instant,
24}
25
26/// UDP stream proxy service
27///
28/// Listens on a port and proxies UDP datagrams to registered backends.
29/// Maintains session state to route responses back to the correct client.
30pub struct UdpStreamService {
31    registry: Arc<StreamRegistry>,
32    listen_port: u16,
33    session_timeout: Duration,
34}
35
36impl UdpStreamService {
37    /// Create a new UDP stream service
38    #[must_use]
39    pub fn new(
40        registry: Arc<StreamRegistry>,
41        listen_port: u16,
42        session_timeout: Option<Duration>,
43    ) -> Self {
44        Self {
45            registry,
46            listen_port,
47            session_timeout: session_timeout.unwrap_or(DEFAULT_UDP_SESSION_TIMEOUT),
48        }
49    }
50
51    /// Get the listen port
52    #[must_use]
53    pub fn port(&self) -> u16 {
54        self.listen_port
55    }
56
57    /// Get the session timeout
58    #[must_use]
59    pub fn session_timeout(&self) -> Duration {
60        self.session_timeout
61    }
62
63    /// Get a reference to the registry
64    #[must_use]
65    pub fn registry(&self) -> &Arc<StreamRegistry> {
66        &self.registry
67    }
68
69    /// Run the UDP proxy service by binding its own socket.
70    ///
71    /// This method runs indefinitely, proxying UDP datagrams between
72    /// clients and backends. Each client address gets its own session.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if binding to the listen port fails or if the
77    /// main receive loop encounters a fatal IO error.
78    pub async fn run(self: Arc<Self>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
79        // Bind to listen port
80        let listen_addr = format!("0.0.0.0:{}", self.listen_port);
81        let socket = UdpSocket::bind(&listen_addr).await?;
82
83        tracing::info!(port = self.listen_port, "UDP stream proxy listening");
84
85        self.serve(socket).await
86    }
87
88    /// Run the UDP proxy service on an externally-provided socket.
89    ///
90    /// This is the non-self-binding entry point, used by `ProxyManager` to serve
91    /// UDP endpoints when the caller has already bound the socket.
92    ///
93    /// Runs indefinitely, proxying UDP datagrams between clients and backends.
94    /// Each client address gets its own session with a dedicated backend socket.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the main receive loop encounters a fatal IO error
99    /// or if creating a backend session socket fails.
100    #[allow(clippy::too_many_lines)]
101    pub async fn serve(
102        self: Arc<Self>,
103        socket: UdpSocket,
104    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
105        let socket = Arc::new(socket);
106
107        tracing::info!(
108            port = self.listen_port,
109            "UDP stream proxy serving (standalone)"
110        );
111
112        // Session tracking: client_addr -> session
113        let sessions: Arc<DashMap<SocketAddr, UdpSession>> = Arc::new(DashMap::new());
114
115        // Channel for backend responses to be sent back to clients
116        let (response_tx, mut response_rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(4096);
117
118        // Spawn response sender task
119        let socket_for_responses = socket.clone();
120        tokio::spawn(async move {
121            while let Some((data, client_addr)) = response_rx.recv().await {
122                if let Err(e) = socket_for_responses.send_to(&data, client_addr).await {
123                    tracing::debug!(
124                        error = %e,
125                        client = %client_addr,
126                        "Failed to send UDP response to client"
127                    );
128                }
129            }
130        });
131
132        // Spawn session cleanup task
133        let sessions_for_cleanup = sessions.clone();
134        let timeout = self.session_timeout;
135        tokio::spawn(async move {
136            let mut interval = tokio::time::interval(Duration::from_secs(10));
137            loop {
138                interval.tick().await;
139                let now = Instant::now();
140                let before = sessions_for_cleanup.len();
141                sessions_for_cleanup
142                    .retain(|_, session| now.duration_since(session.last_activity) < timeout);
143                let after = sessions_for_cleanup.len();
144                if before != after {
145                    tracing::debug!(
146                        removed = before - after,
147                        remaining = after,
148                        "Cleaned up expired UDP sessions"
149                    );
150                }
151            }
152        });
153
154        // Main receive loop
155        let mut buf = vec![0u8; 65535];
156        loop {
157            let (len, client_addr) = socket.recv_from(&mut buf).await?;
158            let data = buf[..len].to_vec();
159
160            // Get or create session for this client
161            let session_backend = if let Some(mut existing) = sessions.get_mut(&client_addr) {
162                existing.last_activity = Instant::now();
163                existing.backend
164            } else {
165                // Create new session
166                let Some(service) = self.registry.resolve_udp(self.listen_port) else {
167                    tracing::warn!(
168                        port = self.listen_port,
169                        client = %client_addr,
170                        "No service registered for UDP port"
171                    );
172                    continue;
173                };
174
175                let Some(backend) = service.select_backend() else {
176                    tracing::warn!(
177                        port = self.listen_port,
178                        service = %service.name,
179                        client = %client_addr,
180                        "No backends available for UDP service"
181                    );
182                    continue;
183                };
184
185                // Create dedicated socket for this session's backend communication
186                let backend_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
187                backend_socket.connect(&backend).await?;
188
189                tracing::debug!(
190                    port = self.listen_port,
191                    service = %service.name,
192                    client = %client_addr,
193                    backend = %backend,
194                    "Created new UDP session"
195                );
196
197                // Spawn task to receive responses from backend
198                let backend_socket_recv = backend_socket.clone();
199                let response_tx = response_tx.clone();
200                let client = client_addr;
201                let sessions_ref = sessions.clone();
202                tokio::spawn(async move {
203                    let mut buf = vec![0u8; 65535];
204                    loop {
205                        match backend_socket_recv.recv(&mut buf).await {
206                            Ok(len) => {
207                                // Update session activity
208                                if let Some(mut s) = sessions_ref.get_mut(&client) {
209                                    s.last_activity = Instant::now();
210                                }
211                                // Send response back to client
212                                if response_tx
213                                    .send((buf[..len].to_vec(), client))
214                                    .await
215                                    .is_err()
216                                {
217                                    break; // Channel closed
218                                }
219                            }
220                            Err(e) => {
221                                tracing::debug!(
222                                    error = %e,
223                                    client = %client,
224                                    "Backend socket receive error"
225                                );
226                                break;
227                            }
228                        }
229                    }
230                });
231
232                let session = UdpSession {
233                    backend,
234                    backend_socket,
235                    last_activity: Instant::now(),
236                };
237                sessions.insert(client_addr, session);
238                backend
239            };
240
241            // Forward packet to backend
242            if let Some(s) = sessions.get(&client_addr) {
243                if let Err(e) = s.backend_socket.send(&data).await {
244                    tracing::debug!(
245                        error = %e,
246                        client = %client_addr,
247                        backend = %session_backend,
248                        "Failed to forward UDP packet to backend"
249                    );
250                }
251            }
252        }
253    }
254}