use dashmap::DashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use super::config::DEFAULT_UDP_SESSION_TIMEOUT;
use super::registry::StreamRegistry;
/// Per-client state for one proxied UDP flow, keyed by the client's address.
struct UdpSession {
    /// When traffic was last seen for this session (client datagram or backend reply);
    /// sessions idle for longer than the configured timeout are swept.
    last_activity: Instant,
    /// Backend address selected for this client when the session was created.
    backend: SocketAddr,
    /// Socket connected to `backend`; also shared with the task reading replies.
    backend_socket: Arc<UdpSocket>,
}
/// UDP proxy that relays datagrams received on `listen_port` to a backend of the
/// service registered for that port.
pub struct UdpStreamService {
    /// Port the proxy receives on; also the key used to resolve the UDP service.
    listen_port: u16,
    /// Idle duration after which a client session is discarded.
    session_timeout: Duration,
    /// Registry consulted to find the service (and its backends) for `listen_port`.
    registry: Arc<StreamRegistry>,
}
impl UdpStreamService {
#[must_use]
pub fn new(
registry: Arc<StreamRegistry>,
listen_port: u16,
session_timeout: Option<Duration>,
) -> Self {
Self {
registry,
listen_port,
session_timeout: session_timeout.unwrap_or(DEFAULT_UDP_SESSION_TIMEOUT),
}
}
#[must_use]
pub fn port(&self) -> u16 {
self.listen_port
}
#[must_use]
pub fn session_timeout(&self) -> Duration {
self.session_timeout
}
#[must_use]
pub fn registry(&self) -> &Arc<StreamRegistry> {
&self.registry
}
pub async fn run(self: Arc<Self>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let listen_addr = format!("0.0.0.0:{}", self.listen_port);
let socket = UdpSocket::bind(&listen_addr).await?;
tracing::info!(port = self.listen_port, "UDP stream proxy listening");
self.serve(socket).await
}
#[allow(clippy::too_many_lines)]
pub async fn serve(
self: Arc<Self>,
socket: UdpSocket,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let socket = Arc::new(socket);
tracing::info!(
port = self.listen_port,
"UDP stream proxy serving (standalone)"
);
let sessions: Arc<DashMap<SocketAddr, UdpSession>> = Arc::new(DashMap::new());
let (response_tx, mut response_rx) = mpsc::channel::<(Vec<u8>, SocketAddr)>(4096);
let socket_for_responses = socket.clone();
tokio::spawn(async move {
while let Some((data, client_addr)) = response_rx.recv().await {
if let Err(e) = socket_for_responses.send_to(&data, client_addr).await {
tracing::debug!(
error = %e,
client = %client_addr,
"Failed to send UDP response to client"
);
}
}
});
let sessions_for_cleanup = sessions.clone();
let timeout = self.session_timeout;
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
interval.tick().await;
let now = Instant::now();
let before = sessions_for_cleanup.len();
sessions_for_cleanup
.retain(|_, session| now.duration_since(session.last_activity) < timeout);
let after = sessions_for_cleanup.len();
if before != after {
tracing::debug!(
removed = before - after,
remaining = after,
"Cleaned up expired UDP sessions"
);
}
}
});
let mut buf = vec![0u8; 65535];
loop {
let (len, client_addr) = socket.recv_from(&mut buf).await?;
let data = buf[..len].to_vec();
let session_backend = if let Some(mut existing) = sessions.get_mut(&client_addr) {
existing.last_activity = Instant::now();
existing.backend
} else {
let Some(service) = self.registry.resolve_udp(self.listen_port) else {
tracing::warn!(
port = self.listen_port,
client = %client_addr,
"No service registered for UDP port"
);
continue;
};
let Some(backend) = service.select_backend() else {
tracing::warn!(
port = self.listen_port,
service = %service.name,
client = %client_addr,
"No backends available for UDP service"
);
continue;
};
let backend_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
backend_socket.connect(&backend).await?;
tracing::debug!(
port = self.listen_port,
service = %service.name,
client = %client_addr,
backend = %backend,
"Created new UDP session"
);
let backend_socket_recv = backend_socket.clone();
let response_tx = response_tx.clone();
let client = client_addr;
let sessions_ref = sessions.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 65535];
loop {
match backend_socket_recv.recv(&mut buf).await {
Ok(len) => {
if let Some(mut s) = sessions_ref.get_mut(&client) {
s.last_activity = Instant::now();
}
if response_tx
.send((buf[..len].to_vec(), client))
.await
.is_err()
{
break; }
}
Err(e) => {
tracing::debug!(
error = %e,
client = %client,
"Backend socket receive error"
);
break;
}
}
}
});
let session = UdpSession {
backend,
backend_socket,
last_activity: Instant::now(),
};
sessions.insert(client_addr, session);
backend
};
if let Some(s) = sessions.get(&client_addr) {
if let Err(e) = s.backend_socket.send(&data).await {
tracing::debug!(
error = %e,
client = %client_addr,
backend = %session_backend,
"Failed to forward UDP packet to backend"
);
}
}
}
}
}