use crate::common::config::env_loader;
use crate::ingress::tasks::ConnectionGuard;
use dashmap::DashMap;
use fancy_log::{LogLevel, log};
use once_cell::sync::Lazy;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
use tokio::net::UdpSocket;
/// Bytes currently reserved against the global pending-buffer limit
/// (`QUIC_GLOBAL_PENDING_BYTES_LIMIT`).
///
/// Increased by [`try_reserve_global_bytes`] and decreased by [`release_global_bytes`].
static GLOBAL_PENDING_BYTES: AtomicUsize = AtomicUsize::new(0);
/// Returns the global cap, in bytes, on data held in pending QUIC buffers.
///
/// The value is queried through `env_loader::get_env` on every call, using the
/// `QUIC_GLOBAL_PENDING_BYTES_LIMIT` key. It falls back to 64 MiB (67 108 864 bytes) when the
/// configured value does not parse as a `usize`.
fn get_global_byte_limit() -> usize {
    const FALLBACK_BYTES: usize = 67_108_864;
    let configured = env_loader::get_env("QUIC_GLOBAL_PENDING_BYTES_LIMIT", "67108864".to_owned());
    configured.parse::<usize>().unwrap_or(FALLBACK_BYTES)
}
/// Returns the per-session buffer cap, in bytes.
///
/// The value is queried through `env_loader::get_env` on every call, using the
/// `QUIC_SESSION_BUFFER_LIMIT` key. It falls back to 64 KiB (65 536 bytes) when the configured
/// value does not parse as a `usize`.
fn get_session_byte_limit() -> usize {
    const FALLBACK_BYTES: usize = 65_536;
    let configured = env_loader::get_env("QUIC_SESSION_BUFFER_LIMIT", "65536".to_owned());
    configured.parse::<usize>().unwrap_or(FALLBACK_BYTES)
}
/// Atomically reserves `amount` bytes against the global pending-buffer limit.
///
/// On success, [`GLOBAL_PENDING_BYTES`] is increased by `amount` and `true` is returned.
/// Bytes reserved here are handed back through [`release_global_bytes`] (e.g. by `PendingState`).
/// If the reservation would exceed `QUIC_GLOBAL_PENDING_BYTES_LIMIT`, or would overflow `usize`,
/// the counter is left untouched, a warning is logged and `false` is returned.
pub fn try_reserve_global_bytes(amount: usize) -> bool {
    let limit = get_global_byte_limit();
    // Check and add in one CAS loop. A separate load followed by fetch_add lets concurrent
    // callers all pass the check against the same stale value and overshoot the limit together.
    // `checked_add` rejects sums that would wrap around `usize`.
    // Relaxed is enough: the counter is a standalone budget and does not order other memory.
    let reservation =
        GLOBAL_PENDING_BYTES.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            current.checked_add(amount).filter(|&next| next <= limit)
        });
    match reservation {
        Ok(_) => true,
        Err(current) => {
            log(
                LogLevel::Warn,
                &format!(
                    "⚠ QUIC Global Buffer Limit Exceeded! Dropping {amount} bytes (Current: {current}/{limit})"
                ),
            );
            false
        }
    }
}
pub fn release_global_bytes(amount: usize) {
GLOBAL_PENDING_BYTES.fetch_sub(amount, Ordering::Relaxed);
}
/// Routing decision recorded for a connection ID in [`CID_REGISTRY`].
///
/// Both variants carry a `last_seen` timestamp. [`touch_session`] refreshes it, and
/// [`cleanup_sessions`] compares it against the session timeout.
#[derive(Debug, Clone)]
pub enum SessionAction {
/// NOTE(review): presumably packets for this CID are relayed to `target_addr` through
/// `upstream_socket`; confirm against the callers of `get_session`.
Forward {
target_addr: SocketAddr,
upstream_socket: Arc<UdpSocket>,
last_seen: Instant,
/// Never read in this file; presumably held for its `Drop` side effects — confirm what
/// `ConnectionGuard` tracks.
_guard: ConnectionGuard,
},
/// NOTE(review): presumably the connection is handed to a local muxer on `muxer_port`;
/// confirm against the callers of `get_session`.
Terminate {
muxer_port: u16,
last_seen: Instant,
/// Optional guard; never read in this file.
_guard: Option<ConnectionGuard>,
},
}
/// Per-CID state kept in [`PENDING_INITIALS`] until [`register_session`] removes it and records a
/// [`SessionAction`] in [`CID_REGISTRY`] (or until [`cleanup_sessions`] expires it).
///
/// `total_bytes` is the amount this state hands back to the global pending-buffer counter.
/// `drain_queue` releases the bytes of drained packets, and `Drop` releases whatever remains.
#[derive(Debug)]
pub struct PendingState {
/// Buffered data keyed by a `usize` — presumably the CRYPTO frame offset, kept sorted by the
/// `BTreeMap` for in-order reassembly; confirm against the producer.
pub crypto_stream: BTreeMap<usize, Vec<u8>>,
/// Packets held back while the CID is pending; emptied by `drain_queue`.
/// NOTE(review): the meaning of the two addresses (e.g. source/destination) should be confirmed
/// against the producer.
pub queued_packets: Vec<(bytes::Bytes, SocketAddr, SocketAddr)>,
/// Last activity time; `cleanup_sessions` drops pending entries idle for 10 seconds or more.
pub last_seen: Instant,
/// Not read in this file; presumably marks that a task is currently handling this entry —
/// confirm at call sites.
pub processing: bool,
/// Never read in this file; presumably held for its `Drop` side effects — confirm.
pub _guard: ConnectionGuard,
/// Bytes released back to the global pending-buffer counter by `drain_queue` and `Drop`.
pub total_bytes: usize,
}
impl PendingState {
    /// Takes every queued packet out of this state and returns them in their queued order.
    ///
    /// The drained payload sizes are removed from `total_bytes`, and the same amount is returned
    /// to the global budget via [`release_global_bytes`]. Release is capped at `total_bytes`, so
    /// this state never frees more than it accounted for (the amount `Drop` would otherwise
    /// release).
    pub fn drain_queue(&mut self) -> Vec<(bytes::Bytes, SocketAddr, SocketAddr)> {
        let packets = std::mem::take(&mut self.queued_packets);
        let drained_size: usize = packets.iter().map(|(data, _, _)| data.len()).sum();
        // Release exactly what is subtracted from `total_bytes`. If `drained_size` exceeded
        // `total_bytes`, releasing the full `drained_size` would push the global counter below
        // the real usage.
        let released = drained_size.min(self.total_bytes);
        self.total_bytes -= released;
        if released > 0 {
            release_global_bytes(released);
        }
        packets
    }
}
impl Drop for PendingState {
    /// Hands any bytes still accounted to this state back to the global pending-buffer budget.
    fn drop(&mut self) {
        let outstanding = std::mem::take(&mut self.total_bytes);
        if outstanding != 0 {
            release_global_bytes(outstanding);
        }
    }
}
/// Established sessions: connection-ID bytes → routing decision.
/// Filled by [`register_session`] and evicted by [`cleanup_sessions`] once idle past its timeout.
pub static CID_REGISTRY: Lazy<DashMap<Vec<u8>, SessionAction>> = Lazy::new(DashMap::new);
/// Pending per-CID state (see [`PendingState`]). Entries are removed by [`register_session`], or
/// by [`cleanup_sessions`] after 10 seconds without activity.
pub static PENDING_INITIALS: Lazy<DashMap<Vec<u8>, PendingState>> = Lazy::new(DashMap::new);
/// Client address → (target address, upstream socket, last-used time, guard).
/// Filled by [`register_sticky`], refreshed by [`get_sticky`], and expired by [`cleanup_sessions`]
/// after `QUIC_STICKY_SESSION_TTL` seconds (default 60).
pub static IP_STICKY_MAP: Lazy<
DashMap<SocketAddr, (SocketAddr, Arc<UdpSocket>, Instant, ConnectionGuard)>,
> = Lazy::new(DashMap::new);
/// Promotes `cid` to an established session routed by `action`.
///
/// Any pending state for the CID is removed first. Dropping that `PendingState` returns its
/// buffered bytes to the global budget. An existing action for the same CID is overwritten.
pub fn register_session(cid: Vec<u8>, action: SessionAction) {
    let _ = PENDING_INITIALS.remove(&cid);
    CID_REGISTRY.insert(cid, action);
}
/// Pins `client` to `target` via `socket`, stamping the entry with the current time.
///
/// Any previous mapping for `client` is replaced.
pub fn register_sticky(
    client: SocketAddr,
    target: SocketAddr,
    socket: Arc<UdpSocket>,
    guard: ConnectionGuard,
) {
    let entry = (target, socket, Instant::now(), guard);
    IP_STICKY_MAP.insert(client, entry);
}
/// Looks up the sticky target for `client` and refreshes its last-used timestamp.
///
/// Returns the target address and a shared handle to its upstream socket, or `None` if the
/// client has no sticky mapping.
pub fn get_sticky(client: &SocketAddr) -> Option<(SocketAddr, Arc<UdpSocket>)> {
    IP_STICKY_MAP.get_mut(client).map(|mut slot| {
        slot.2 = Instant::now();
        (slot.0, Arc::clone(&slot.1))
    })
}
/// Returns a clone of the routing decision registered for `cid`, if any.
pub fn get_session(cid: &[u8]) -> Option<SessionAction> {
    let entry = CID_REGISTRY.get(cid)?;
    Some(entry.value().clone())
}
/// Marks the session for `cid` as active now. Unknown CIDs are ignored.
pub fn touch_session(cid: &[u8]) {
    if let Some(mut entry) = CID_REGISTRY.get_mut(cid) {
        // Both variants carry `last_seen`, so one irrefutable or-pattern reaches it.
        let (SessionAction::Forward { last_seen, .. } | SessionAction::Terminate { last_seen, .. }) =
            entry.value_mut();
        *last_seen = Instant::now();
    }
}
/// Returns whether adding `add` bytes to a session buffer holding `current` bytes stays within
/// `QUIC_SESSION_BUFFER_LIMIT`.
///
/// When the limit would be exceeded, a warning is logged and `false` is returned. This also
/// applies when `current + add` overflows `usize`. The function only checks; it reserves nothing.
#[must_use]
pub fn check_session_limit(current: usize, add: usize) -> bool {
    let limit = get_session_byte_limit();
    // Plain `current + add` panics in debug builds and wraps (possibly passing the check) in
    // release builds. A sum that overflows is over any finite limit.
    let within_limit = matches!(current.checked_add(add), Some(total) if total <= limit);
    if !within_limit {
        log(
            LogLevel::Warn,
            &format!("⚠ QUIC Session Buffer Limit Exceeded! Dropping (Current: {current}/{limit})"),
        );
    }
    within_limit
}
/// Evicts idle entries from the three registries, using one reference instant for all ages.
///
/// - `CID_REGISTRY`: removed when idle for `timeout_secs` whole seconds or more.
/// - `PENDING_INITIALS`: removed when idle for 10 seconds or more.
/// - `IP_STICKY_MAP`: removed when idle for `QUIC_STICKY_SESSION_TTL` seconds or more
///   (60 if unset or unparsable).
pub fn cleanup_sessions(timeout_secs: u64) {
    let now = Instant::now();
    let age_secs = |then: Instant| now.duration_since(then).as_secs();

    CID_REGISTRY.retain(|_, action| {
        let (SessionAction::Forward { last_seen, .. } | SessionAction::Terminate { last_seen, .. }) =
            action;
        age_secs(*last_seen) < timeout_secs
    });

    PENDING_INITIALS.retain(|_, state| age_secs(state.last_seen) < 10);

    let sticky_timeout = env_loader::get_env("QUIC_STICKY_SESSION_TTL", "60".to_owned())
        .parse::<u64>()
        .unwrap_or(60);
    IP_STICKY_MAP.retain(|_, entry| age_secs(entry.2) < sticky_timeout);
}
/// Spawns a background tokio task that periodically runs [`cleanup_sessions`].
///
/// The session TTL is read once from `QUIC_SESSION_TTL_SECS` (300 if unset or unparsable). The
/// task wakes every `ttl / 2` seconds, but never more often than once per second.
pub fn start_cleanup_task() {
    use tokio::time::{Duration, sleep};
    log(LogLevel::Debug, "⚙ Starting QUIC session cleanup task...");
    tokio::spawn(async move {
        let ttl_str = env_loader::get_env("QUIC_SESSION_TTL_SECS", "300".to_owned());
        let ttl = ttl_str.parse::<u64>().unwrap_or(300);
        // For ttl <= 1, `ttl / 2` is 0. A zero-length sleep completes immediately, which would
        // make this loop spin and monopolize a runtime worker, so clamp to at least one second.
        let check_interval = Duration::from_secs((ttl / 2).max(1));
        loop {
            sleep(check_interval).await;
            cleanup_sessions(ttl);
        }
    });
}