use super::session::{self, SessionAction};
use super::virtual_socket::{VirtualPacket, VirtualUdpSocket};
use crate::common::config::env_loader;
use crate::common::sys::lifecycle::{Error, Result};
use crate::layers::l7::http::h3;
use crate::resources::certs;
use fancy_log::{LogLevel, log};
use quinn::{ConnectionId, ConnectionIdGenerator};
use rand::Rng;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio_rustls::rustls;
/// Handle used to push datagrams into the QUIC endpoint task created for one port.
///
/// Instances are built by `QuicMuxer::new`, which spawns the endpoint task, and are
/// shared through `MUXER_REGISTRY` as `Arc<QuicMuxer>`.
pub struct QuicMuxer {
/// Sending half of the bounded packet channel; the receiving half is handed to the
/// `VirtualUdpSocket` created inside the spawned endpoint task.
tx: mpsc::Sender<VirtualPacket>,
/// Set at construction and refreshed on each `get_or_create` lookup; compared against
/// the 300-second idle limit when the registry is pruned.
last_active: Mutex<Instant>,
}
/// Process-wide table of muxers keyed by port.
///
/// The map is created lazily by the first `QuicMuxer::get_or_create` call and is only
/// accessed while holding its mutex.
static MUXER_REGISTRY: std::sync::OnceLock<Mutex<HashMap<u16, Arc<QuicMuxer>>>> =
std::sync::OnceLock::new();
/// Connection-ID generator installed on the quinn endpoint config of one muxer.
#[derive(Debug)]
struct VaneCidGenerator {
/// Port of the owning muxer; stored as `muxer_port` in every session entry this
/// generator registers.
port: u16,
}
/// Length in bytes of every connection ID produced by `VaneCidGenerator`.
const VANE_CID_LEN: usize = 8;

/// Generates random fixed-length connection IDs and records each one in the session
/// table before handing it to quinn.
impl ConnectionIdGenerator for VaneCidGenerator {
    /// Draws `VANE_CID_LEN` random bytes, registers them via `session::register_session`
    /// with a `SessionAction::Terminate` entry carrying this generator's port, and
    /// returns them as a `ConnectionId`.
    fn generate_cid(&mut self) -> ConnectionId {
        let mut raw = [0u8; VANE_CID_LEN];
        rand::rng().fill(&mut raw);

        let action = SessionAction::Terminate {
            muxer_port: self.port,
            last_seen: Instant::now(),
            _guard: None,
        };
        session::register_session(raw.to_vec(), action);

        ConnectionId::new(&raw)
    }

    /// Every generated connection ID has the same fixed length.
    fn cid_len(&self) -> usize {
        VANE_CID_LEN
    }

    /// Reports no lifetime for generated connection IDs.
    fn cid_lifetime(&self) -> Option<Duration> {
        None
    }
}
impl QuicMuxer {
/// Returns the muxer registered for `port`, creating and registering one if needed.
///
/// A registered muxer is reused only while its packet channel is still open; reusing it
/// refreshes its `last_active` timestamp. `cert_sni` and `physical_socket` are consulted
/// only when a new muxer has to be built.
///
/// Before a new muxer is inserted the registry is pruned: entries whose channel has been
/// closed, or whose last lookup is 300 seconds or more in the past, are removed.
///
/// # Panics
///
/// Building a new muxer spawns a Tokio task, which panics outside a Tokio runtime.
pub fn get_or_create(port: u16, cert_sni: &str, physical_socket: Arc<UdpSocket>) -> Arc<Self> {
    let registry = MUXER_REGISTRY.get_or_init(|| Mutex::new(HashMap::new()));
    // A panic while the lock was held (for example inside `Self::new`) poisons the mutex
    // before anything was inserted, so the map itself is still consistent. Recover the
    // guard instead of turning one failure into a panic on every later call.
    let mut map = registry
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());

    if let Some(muxer) = map.get(&port) {
        // The receiver lives in the endpoint task. Once that task has exited (for example
        // because the server config could not be built) the channel is closed and this
        // muxer can never deliver a packet again, so fall through and replace it.
        if !muxer.tx.is_closed() {
            if let Ok(mut t) = muxer.last_active.lock() {
                *t = Instant::now();
            }
            return Arc::clone(muxer);
        }
    }

    let now = Instant::now();
    map.retain(|_, muxer| {
        if muxer.tx.is_closed() {
            return false;
        }
        match muxer.last_active.lock() {
            Ok(t) => now.duration_since(*t).as_secs() < 300,
            // A poisoned timestamp gives no idle information; keep the entry.
            Err(_) => true,
        }
    });

    let muxer = Arc::new(Self::new(port, cert_sni, physical_socket));
    map.insert(port, Arc::clone(&muxer));
    muxer
}
fn new(port: u16, cert_sni: &str, physical_socket: Arc<UdpSocket>) -> Self {
log(
LogLevel::Info,
&format!("➜ Initializing QUIC Muxer (Virtual Socket) for port {port}"),
);
let channel_cap_str = env_loader::get_env("QUIC_VIRTUAL_CHANNEL_CAPACITY", "1024".to_owned());
let channel_cap = channel_cap_str.parse::<usize>().unwrap_or(1024);
let (tx, rx) = mpsc::channel::<VirtualPacket>(channel_cap);
let cert_id = cert_sni.to_owned();
tokio::spawn(async move {
let mut endpoint_config = quinn::EndpointConfig::default();
endpoint_config.cid_generator(move || Box::new(VaneCidGenerator { port }));
let server_config = match Self::build_server_config(&cert_id) {
Ok(c) => c,
Err(e) => {
log(
LogLevel::Error,
&format!("✗ Failed to build QUIC config: {e}"),
);
return;
}
};
let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
let virtual_socket = Arc::new(VirtualUdpSocket::new(rx, physical_socket, local_addr));
let endpoint = match quinn::Endpoint::new_with_abstract_socket(
endpoint_config,
Some(server_config),
virtual_socket,
Arc::new(quinn::TokioRuntime),
) {
Ok(e) => e,
Err(e) => {
log(
LogLevel::Error,
&format!("✗ Failed to create QUIC endpoint: {e}"),
);
return;
}
};
log(
LogLevel::Info,
&format!("✓ QUIC Endpoint initialized (port {port})"),
);
while let Some(incoming) = endpoint.accept().await {
tokio::spawn(async move {
match incoming.await {
Ok(conn) => {
if let Err(e) = h3::handle_connection(conn).await {
log(LogLevel::Error, &format!("✗ H3 Engine Error: {e:#}"));
}
}
Err(e) => log(LogLevel::Warn, &format!("⚠ QUIC Handshake Error: {e}")),
}
});
}
});
Self {
tx,
last_active: Mutex::new(Instant::now()),
}
}
fn build_server_config(cert_id: &str) -> Result<quinn::ServerConfig> {
let cert = certs::arcswap::get_certificate(cert_id)
.ok_or_else(|| Error::Configuration("Certificate not found".to_owned()))?;
let mut crypto = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(cert.certs.clone(), cert.key_clone()?)
.map_err(|e| Error::Tls(e.to_string()))?;
crypto.alpn_protocols = vec![b"h3".to_vec()];
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(
quinn::crypto::rustls::QuicServerConfig::try_from(crypto)
.map_err(|e| Error::Tls(e.to_string()))?,
));
let mut transport = quinn::TransportConfig::default();
transport.max_idle_timeout(
std::time::Duration::from_secs(30)
.try_into()
.ok()
.map(Some)
.unwrap_or(None),
);
transport.keep_alive_interval(Some(std::time::Duration::from_secs(10)));
server_config.transport_config(Arc::new(transport));
Ok(server_config)
}
pub fn feed_packet(
&self,
data: bytes::Bytes,
src_addr: SocketAddr,
dst_addr: SocketAddr,
) -> Result<()> {
let packet = VirtualPacket {
data,
src_addr,
dst_addr,
};
match self.tx.try_send(packet) {
Ok(_) | Err(mpsc::error::TrySendError::Full(_)) => Ok(()),
Err(mpsc::error::TrySendError::Closed(_)) => Err(
crate::common::sys::lifecycle::Error::System("QUIC Muxer channel closed".to_owned()),
),
}
}
}