use std::net::{SocketAddr, IpAddr};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;
use tokio::net::UdpSocket;
use std::sync::Arc;
#[allow(unused_imports)]
use relay_core_api::flow::{Flow, FlowUpdate, NetworkInfo, TransportProtocol, Layer, UdpLayer};
#[allow(unused_imports)]
use chrono::Utc;
#[cfg(target_os = "linux")]
use crate::capture::linux_tproxy::LinuxTproxy;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Identifies a UDP session by its source (client) and destination (server) endpoints.
///
/// Two packets belong to the same session exactly when all four components match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct UdpSessionKey {
    /// Source (client) IP address.
    pub src_ip: IpAddr,
    /// Source (client) port.
    pub src_port: u16,
    /// Destination (server) IP address.
    pub dst_ip: IpAddr,
    /// Destination (server) port.
    pub dst_port: u16,
}

impl UdpSessionKey {
    /// Builds a key from the packet's source and destination socket addresses.
    pub fn new(src: SocketAddr, dst: SocketAddr) -> Self {
        let (src_ip, src_port) = (src.ip(), src.port());
        let (dst_ip, dst_port) = (dst.ip(), dst.port());
        Self {
            src_ip,
            src_port,
            dst_ip,
            dst_port,
        }
    }
}
/// State for one proxied UDP session (a single source/destination address pair).
///
/// The activity timestamp and the counters live behind `Arc`s, so every clone of a
/// session shares and observes the same values.
#[derive(Debug, Clone)]
pub struct UdpSession {
    /// Identifier of the flow associated with this session.
    pub flow_id: Uuid,
    /// Address 4-tuple the session is keyed on.
    pub key: UdpSessionKey,
    /// When the session was created.
    pub created_at: Instant,
    /// Time of the most recently observed activity on the session.
    pub last_activity: Arc<RwLock<Instant>>,
    /// Number of packets attributed to the session.
    pub packet_count: Arc<AtomicUsize>,
    /// Number of payload bytes forwarded for the session.
    pub bytes_transferred: Arc<AtomicUsize>,
    /// Linux only: transparent socket used to talk to the destination (server).
    #[cfg(target_os = "linux")]
    pub upstream_socket: Option<Arc<UdpSocket>>,
    /// Linux only: transparent socket used to send replies back to the source (client).
    #[cfg(target_os = "linux")]
    pub downstream_socket: Option<Arc<UdpSocket>>,
}
/// Map entry pairing a [`UdpSession`] with the task (if any) relaying its
/// server-to-client traffic.
///
/// Dropping the entry, whether by removal from the map or by dropping the whole
/// manager, aborts the relay task. The task's socket handles are then released
/// instead of being kept alive by a task that would otherwise `recv` forever.
struct SessionEntry {
    session: UdpSession,
    relay: Option<tokio::task::JoinHandle<()>>,
}

impl Drop for SessionEntry {
    fn drop(&mut self) {
        if let Some(relay) = &self.relay {
            relay.abort();
        }
    }
}

/// Tracks active UDP sessions keyed by their address 4-tuple and expires idle ones.
pub struct UdpSessionManager {
    sessions: RwLock<HashMap<UdpSessionKey, SessionEntry>>,
    idle_timeout: Duration,
}

impl UdpSessionManager {
    /// Creates an empty manager whose sessions are considered idle after
    /// `idle_timeout` without activity.
    pub fn new(idle_timeout: Duration) -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
            idle_timeout,
        }
    }

    /// Returns the session for `src` -> `dst`, creating it if it does not exist.
    ///
    /// The boolean is `true` when the session was created by this call. For an
    /// existing session, the activity timestamp is refreshed and the packet counter
    /// incremented. On Linux, a new session gets a pair of transparent sockets and a
    /// background task that relays upstream replies back to the client.
    ///
    /// # Errors
    /// On Linux, propagates I/O errors from creating or connecting the transparent sockets.
    pub async fn get_or_create_session(
        &self,
        src: SocketAddr,
        dst: SocketAddr,
    ) -> std::io::Result<(UdpSession, bool)> {
        let key = UdpSessionKey::new(src, dst);

        // Fast path: most packets belong to an existing session, so try a shared lock first.
        {
            let sessions = self.sessions.read().await;
            if let Some(entry) = sessions.get(&key) {
                return Ok((Self::touch(&entry.session).await, false));
            }
        }

        let mut sessions = self.sessions.write().await;
        // Re-check: another task may have created the session while we waited for the write lock.
        if let Some(entry) = sessions.get(&key) {
            return Ok((Self::touch(&entry.session).await, false));
        }

        #[cfg(target_os = "linux")]
        let (upstream, downstream) = {
            let up = LinuxTproxy::create_transparent_udp_socket(src)?;
            up.connect(dst).await?;
            let down = LinuxTproxy::create_transparent_udp_socket(dst)?;
            down.connect(src).await?;
            (Some(Arc::new(up)), Some(Arc::new(down)))
        };

        let now = Instant::now();
        let session = UdpSession {
            flow_id: Uuid::new_v4(),
            key: key.clone(),
            created_at: now,
            last_activity: Arc::new(RwLock::new(now)),
            packet_count: Arc::new(AtomicUsize::new(1)),
            bytes_transferred: Arc::new(AtomicUsize::new(0)),
            #[cfg(target_os = "linux")]
            upstream_socket: upstream,
            #[cfg(target_os = "linux")]
            downstream_socket: downstream,
        };

        #[cfg(target_os = "linux")]
        let relay = Self::spawn_downstream_relay(&session);
        #[cfg(not(target_os = "linux"))]
        let relay = None;

        sessions.insert(
            key,
            SessionEntry {
                session: session.clone(),
                relay,
            },
        );
        Ok((session, true))
    }

    /// Records one more packet on `session` (timestamp + counter) and returns a clone of it.
    async fn touch(session: &UdpSession) -> UdpSession {
        *session.last_activity.write().await = Instant::now();
        session.packet_count.fetch_add(1, Ordering::Relaxed);
        session.clone()
    }

    /// Spawns the task that copies datagrams received on the upstream socket to the
    /// downstream socket. It returns the task handle, or `None` if the session lacks
    /// either socket.
    ///
    /// The task ends on a fatal socket error, or when its handle is aborted because
    /// the owning [`SessionEntry`] was dropped.
    #[cfg(target_os = "linux")]
    fn spawn_downstream_relay(session: &UdpSession) -> Option<tokio::task::JoinHandle<()>> {
        let up = session.upstream_socket.clone()?;
        let down = session.downstream_socket.clone()?;
        let last_activity = session.last_activity.clone();
        let bytes_transferred = session.bytes_transferred.clone();
        Some(tokio::spawn(async move {
            let mut buf = [0u8; 65535];
            loop {
                let n = match up.recv(&mut buf).await {
                    Ok(n) => n,
                    // A connected UDP socket reports an earlier ICMP port-unreachable as a
                    // one-off ECONNREFUSED; the socket remains usable, so keep relaying.
                    Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
                        tracing::debug!("UDP upstream recv error: {}", e);
                        continue;
                    }
                    Err(e) => {
                        tracing::debug!("UDP upstream recv error: {}", e);
                        break;
                    }
                };
                // Await the lock rather than `try_write`: a skipped update could make an
                // active session look idle and get it reaped.
                *last_activity.write().await = Instant::now();
                bytes_transferred.fetch_add(n, Ordering::Relaxed);
                if let Err(e) = down.send(&buf[..n]).await {
                    tracing::debug!("UDP downstream send error: {}", e);
                    break;
                }
            }
        }))
    }

    /// Removes every session whose last activity is older than the idle timeout and
    /// returns their flow ids.
    ///
    /// Removing a session also aborts its relay task, which releases its sockets.
    pub async fn cleanup_idle_sessions(&self) -> Vec<Uuid> {
        let mut sessions = self.sessions.write().await;
        let now = Instant::now();
        let mut expired = Vec::new();
        for (key, entry) in sessions.iter() {
            let last = *entry.session.last_activity.read().await;
            // The relay task can bump `last_activity` after `now` was captured, so
            // saturate instead of relying on `duration_since`'s behavior for later instants.
            if now.saturating_duration_since(last) > self.idle_timeout {
                expired.push(key.clone());
            }
        }
        expired
            .into_iter()
            .filter_map(|key| sessions.remove(&key))
            .map(|entry| entry.session.flow_id)
            .collect()
    }
}
/// UDP proxy that receives datagrams on a listening socket and tracks per-flow
/// sessions through a [`UdpSessionManager`]. On Linux it runs in TPROXY mode.
pub struct UdpProxy {
    /// Socket on which client datagrams are received.
    socket: Arc<UdpSocket>,
    /// Session table. Only the Linux path of `run` reads it, so the dead-code lint
    /// is silenced on other targets only.
    #[cfg_attr(not(target_os = "linux"), allow(dead_code))]
    session_manager: Arc<UdpSessionManager>,
}
impl UdpProxy {
pub fn new(socket: UdpSocket, idle_timeout: Duration) -> Self {
Self {
socket: Arc::new(socket),
session_manager: Arc::new(UdpSessionManager::new(idle_timeout)),
}
}
pub async fn run(&self, on_flow: Sender<FlowUpdate>) -> crate::error::Result<()>
{
let mut buf = [0u8; 65535];
#[cfg(target_os = "linux")]
{
LinuxTproxy::enable_tproxy(&self.socket)?;
loop {
let (len, src_addr, orig_dst) = match LinuxTproxy::recv_original_dst(&self.socket, &mut buf).await {
Ok(res) => res,
Err(e) => {
tracing::error!("UDP TPROXY recv error: {}", e);
continue;
}
};
if let Some(dst_addr) = orig_dst {
match self.session_manager.get_or_create_session(src_addr, dst_addr).await {
Ok((session, is_new)) => {
if is_new {
let flow = Flow {
id: session.flow_id,
start_time: Utc::now(),
end_time: None,
network: NetworkInfo {
client_ip: src_addr.ip().to_string(),
client_port: src_addr.port(),
server_ip: dst_addr.ip().to_string(),
server_port: dst_addr.port(),
protocol: TransportProtocol::UDP,
tls: false,
tls_version: None,
sni: None,
},
layer: Layer::Udp(UdpLayer {
payload_size: len,
packet_count: 1,
}),
tags: vec![],
meta: HashMap::new(),
};
if on_flow.try_send(FlowUpdate::Full(Box::new(flow))).is_err() {
crate::metrics::inc_flows_dropped();
}
}
if let Some(upstream) = &session.upstream_socket {
if let Err(e) = upstream.send(&buf[..len]).await {
tracing::debug!("UDP upstream send error: {}", e);
} else {
session.bytes_transferred.fetch_add(len, Ordering::Relaxed);
}
}
}
Err(e) => {
tracing::warn!("Failed to create UDP session: {}", e);
}
}
}
}
}
#[cfg(not(target_os = "linux"))]
{
let _ = on_flow;
loop {
let (_len, _src_addr) = self.socket.recv_from(&mut buf).await?;
}
}
}
}