use super::model::ResolvedTarget;
use crate::common::config::env_loader;
use crate::ingress::tasks::ConnectionGuard;
use dashmap::DashMap;
use fancy_log::{LogLevel, log};
use once_cell::sync::Lazy;
use std::mem;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::time::{Duration, Instant};
/// State for one proxied UDP session; stored as `Arc<Session>` in `SESSIONS`.
///
/// Fields drop in declaration order. NOTE(review): `_guard` is declared last, so it is
/// released after the upstream socket handle — confirm that ordering is intended before
/// reordering fields.
pub struct Session {
/// Target this session forwards to. NOTE(review): presumably the resolved upstream
/// produced by `super::model` — confirm at the construction site.
pub target: ResolvedTarget,
/// Socket used for the upstream side of the session. Its `local_addr()` is the key under
/// which this session is registered in `REVERSE_SESSIONS`; the cleanup task removes
/// that entry when the session is evicted.
pub upstream_socket: Arc<UdpSocket>,
/// Timestamp compared against `UDP_SESSION_TIMEOUT_SECS` by the cleanup task, and used to
/// rank sessions (oldest first) when the session buffer limit is exceeded.
/// NOTE(review): presumably refreshed on traffic elsewhere — confirm.
pub last_seen: Instant,
/// Not read in this file; held so that it lives exactly as long as the session.
/// NOTE(review): presumably `ConnectionGuard`'s `Drop` releases a connection slot or
/// counter — confirm in `crate::ingress::tasks`.
pub _guard: ConnectionGuard,
}
/// Live UDP sessions, created lazily on first access.
///
/// Keyed by `(SocketAddr, String)`. NOTE(review): presumably `(client address, listener /
/// route id)` — the meaning of each half is not visible here; confirm at the insert site.
/// `start_session_cleanup_task` removes entries that are idle or that overflow the
/// configured buffer budget.
pub static SESSIONS: Lazy<DashMap<(SocketAddr, String), Arc<Session>>> =
    Lazy::new(DashMap::default);

/// Reverse index from an upstream socket's local address (`Session::upstream_socket`'s
/// `local_addr()`) to a peer address.
///
/// NOTE(review): the value is presumably the originating client address, used to route
/// upstream replies back — confirm at the insert site. Entries are removed alongside their
/// `SESSIONS` entry by the cleanup task.
pub static REVERSE_SESSIONS: Lazy<DashMap<SocketAddr, SocketAddr>> =
    Lazy::new(DashMap::default);
pub fn start_session_cleanup_task() {
log(LogLevel::Debug, "⚙ Starting UDP session cleanup task...");
let buffer_limit_str = env_loader::get_env("UDP_SESSION_BUFFER", "4194304".to_owned());
let buffer_limit = buffer_limit_str.parse::<usize>().unwrap_or(4_194_304);
tokio::spawn(async move {
let session_timeout_secs = env_loader::get_env("UDP_SESSION_TIMEOUT_SECS", "30".to_owned())
.parse::<u64>()
.unwrap_or(30);
let session_timeout = Duration::from_secs(session_timeout_secs);
let mut interval = tokio::time::interval(Duration::from_secs(10));
loop {
interval.tick().await;
let now = Instant::now();
let mut expired_keys = Vec::new();
for entry in SESSIONS.iter() {
if now.duration_since(entry.value().last_seen) > session_timeout {
expired_keys.push(entry.key().clone());
}
}
for key in expired_keys {
if let Some((_, session)) = SESSIONS.remove(&key)
&& let Ok(addr) = session.upstream_socket.local_addr()
{
REVERSE_SESSIONS.remove(&addr);
}
}
let current_size =
SESSIONS.len() * (mem::size_of::<(SocketAddr, String)>() + mem::size_of::<Arc<Session>>());
if current_size > buffer_limit {
log(
LogLevel::Warn,
&format!(
"⚠ UDP session buffer limit exceeded ({current_size} > {buffer_limit}). Pruning oldest sessions."
),
);
let mut all_sessions: Vec<_> = SESSIONS
.iter()
.map(|e| (e.key().clone(), e.value().last_seen))
.collect();
all_sessions.sort_by_key(|a| a.1);
let to_prune_count = (SESSIONS.len() as f64 * 0.1).ceil() as usize;
for (key, _) in all_sessions.iter().take(to_prune_count) {
if let Some((_, session)) = SESSIONS.remove(key)
&& let Ok(addr) = session.upstream_socket.local_addr()
{
REVERSE_SESSIONS.remove(&addr);
}
}
}
}
});
}