use std::collections::HashMap;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::os::fd::AsRawFd;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;
use socket2::{Domain, Protocol, Type};
use tokio::io::unix::AsyncFd;
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use crate::ethernet::{ETH_HEADER_LEN, build_udp_ip_ethernet, prepend_ethernet_header};
struct UdpFlow {
last_active: Instant,
payload_tx: mpsc::Sender<Vec<u8>>,
}
struct UdpProxy {
flows: HashMap<(Ipv4Addr, u16, Ipv4Addr, u16), UdpFlow>,
reply_tx: mpsc::Sender<Vec<u8>>,
gateway_mac: [u8; 6],
}
impl UdpProxy {
fn new(reply_tx: mpsc::Sender<Vec<u8>>, gateway_mac: [u8; 6]) -> Self {
Self {
flows: HashMap::new(),
reply_tx,
gateway_mac,
}
}
fn proxy_udp(&mut self, frame: &[u8], guest_mac: [u8; 6]) {
if frame.len() < ETH_HEADER_LEN + 28 {
return;
}
let ip_start = ETH_HEADER_LEN;
let ihl = ((frame[ip_start] & 0x0F) as usize) * 4;
let l4_start = ip_start + ihl;
if frame.len() < l4_start + 8 {
return;
}
let src_ip = Ipv4Addr::new(
frame[ip_start + 12],
frame[ip_start + 13],
frame[ip_start + 14],
frame[ip_start + 15],
);
let dst_ip = Ipv4Addr::new(
frame[ip_start + 16],
frame[ip_start + 17],
frame[ip_start + 18],
frame[ip_start + 19],
);
let src_port = u16::from_be_bytes([frame[l4_start], frame[l4_start + 1]]);
let dst_port = u16::from_be_bytes([frame[l4_start + 2], frame[l4_start + 3]]);
let udp_len = u16::from_be_bytes([frame[l4_start + 4], frame[l4_start + 5]]) as usize;
if udp_len < 8 || l4_start + udp_len > frame.len() {
return;
}
let payload = frame[l4_start + 8..l4_start + udp_len].to_vec();
let flow_key = (src_ip, src_port, dst_ip, dst_port);
if let Some(flow) = self.flows.get_mut(&flow_key) {
flow.last_active = Instant::now();
if flow.payload_tx.try_send(payload).is_err() {
self.flows.remove(&flow_key);
}
return;
}
let (payload_tx, mut payload_rx) = mpsc::channel::<Vec<u8>>(64);
self.flows.insert(
flow_key,
UdpFlow {
last_active: Instant::now(),
payload_tx,
},
);
let reply_tx = self.reply_tx.clone();
let gateway_mac = self.gateway_mac;
tokio::spawn(async move {
let socket = match UdpSocket::bind("0.0.0.0:0").await {
Ok(s) => s,
Err(e) => {
tracing::warn!("UDP proxy: failed to bind socket: {}", e);
return;
}
};
if let Err(e) = socket
.connect(SocketAddr::V4(SocketAddrV4::new(dst_ip, dst_port)))
.await
{
tracing::warn!("UDP proxy: failed to connect: {}", e);
return;
}
if let Err(e) = socket.send(&payload).await {
tracing::warn!("UDP proxy: send failed: {}", e);
return;
}
let mut buf = vec![0u8; 65535];
loop {
tokio::select! {
msg = payload_rx.recv() => {
match msg {
Some(data) => {
if let Err(e) = socket.send(&data).await {
tracing::debug!("UDP proxy: send failed: {e}");
break;
}
}
None => break, }
}
recv = tokio::time::timeout(
std::time::Duration::from_secs(60),
socket.recv(&mut buf),
) => {
match recv {
Ok(Ok(n)) if n > 0 => {
let reply_frame = build_udp_ip_ethernet(
dst_ip,
src_ip,
dst_port,
src_port,
&buf[..n],
gateway_mac,
guest_mac,
);
if reply_tx.send(reply_frame).await.is_err() {
break;
}
}
_ => break, }
}
}
}
});
}
fn cleanup_stale_flows(&mut self) {
let cutoff = Instant::now()
.checked_sub(std::time::Duration::from_secs(60))
.unwrap();
self.flows.retain(|_, flow| flow.last_active > cutoff);
}
}
struct IcmpProxy {
reply_tx: mpsc::Sender<Vec<u8>>,
gateway_mac: [u8; 6],
disabled: Arc<AtomicBool>,
permission_warned: Arc<AtomicBool>,
}
impl IcmpProxy {
fn new(reply_tx: mpsc::Sender<Vec<u8>>, gateway_mac: [u8; 6]) -> Self {
Self {
reply_tx,
gateway_mac,
disabled: Arc::new(AtomicBool::new(false)),
permission_warned: Arc::new(AtomicBool::new(false)),
}
}
fn proxy_icmp(&self, frame: &[u8], guest_mac: [u8; 6]) {
if self.disabled.load(Ordering::Relaxed) {
return;
}
let ip_start = ETH_HEADER_LEN;
let ihl = ((frame[ip_start] & 0x0F) as usize) * 4;
let icmp_start = ip_start + ihl;
if frame.len() < icmp_start + 8 {
return;
}
let src_ip = Ipv4Addr::new(
frame[ip_start + 12],
frame[ip_start + 13],
frame[ip_start + 14],
frame[ip_start + 15],
);
let dst_ip = Ipv4Addr::new(
frame[ip_start + 16],
frame[ip_start + 17],
frame[ip_start + 18],
frame[ip_start + 19],
);
let icmp_payload = frame[icmp_start..].to_vec();
let reply_tx = self.reply_tx.clone();
let gateway_mac = self.gateway_mac;
let disabled = Arc::clone(&self.disabled);
let permission_warned = Arc::clone(&self.permission_warned);
tokio::spawn(async move {
let icmp_socket = match socket2::Socket::new(
Domain::IPV4,
Type::DGRAM,
Some(Protocol::ICMPV4),
) {
Ok(s) => s,
Err(e) => {
if e.kind() == std::io::ErrorKind::PermissionDenied {
disabled.store(true, Ordering::Relaxed);
if !permission_warned.swap(true, Ordering::Relaxed) {
tracing::warn!(
"ICMP proxy disabled: failed to create ICMP datagram socket: {}",
e
);
}
} else {
tracing::debug!(
"ICMP proxy: failed to create ICMP datagram socket (will retry): {}",
e
);
}
return;
}
};
icmp_socket.set_nonblocking(true).ok();
let dst_addr: SocketAddr = SocketAddrV4::new(dst_ip, 0).into();
if let Err(e) = icmp_socket.send_to(&icmp_payload, &dst_addr.into()) {
tracing::warn!("ICMP proxy: sendto failed: {}", e);
return;
}
let async_fd = match AsyncFd::new(RawSocketWrapper(icmp_socket)) {
Ok(fd) => fd,
Err(e) => {
tracing::warn!("ICMP proxy: AsyncFd failed: {}", e);
return;
}
};
let mut buf = vec![0u8; 65535];
let recv = tokio::time::timeout(std::time::Duration::from_secs(10), async {
loop {
let readable = async_fd.readable().await;
match readable {
Ok(mut guard) => {
match guard.try_io(|inner| {
let n = unsafe {
libc::recv(
inner.get_ref().as_raw_fd(),
buf.as_mut_ptr().cast(),
buf.len(),
0,
)
};
if n < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(n as usize)
}
}) {
Ok(Ok(n)) if n > 0 => return Ok(n),
Ok(Err(e)) => return Err(e),
_ => continue, }
}
Err(e) => return Err(e),
}
}
})
.await;
match recv {
Ok(Ok(n)) if n > 0 => {
let reply_packet = &buf[..n];
let reply_frame = if looks_like_ipv4_icmp(reply_packet) {
prepend_ethernet_header(reply_packet, guest_mac, gateway_mac)
} else {
build_icmp_ipv4_ethernet(
dst_ip,
src_ip,
reply_packet,
gateway_mac,
guest_mac,
)
};
let _ = reply_tx.send(reply_frame).await;
}
_ => {
tracing::trace!("ICMP proxy: no reply for {} -> {}", src_ip, dst_ip);
}
}
});
}
}
struct RawSocketWrapper(socket2::Socket);
impl AsRawFd for RawSocketWrapper {
fn as_raw_fd(&self) -> std::os::fd::RawFd {
self.0.as_raw_fd()
}
}
fn looks_like_ipv4_icmp(packet: &[u8]) -> bool {
if packet.len() < 20 {
return false;
}
let version = packet[0] >> 4;
let ihl = (packet[0] & 0x0F) as usize * 4;
version == 4 && ihl >= 20 && packet.len() >= ihl && packet[9] == 1
}
fn build_icmp_ipv4_ethernet(
src_ip: Ipv4Addr,
dst_ip: Ipv4Addr,
icmp_payload: &[u8],
src_mac: [u8; 6],
dst_mac: [u8; 6],
) -> Vec<u8> {
let ip_total_len = 20 + icmp_payload.len();
let frame_len = ETH_HEADER_LEN + ip_total_len;
let mut frame = vec![0u8; frame_len];
frame[0..6].copy_from_slice(&dst_mac);
frame[6..12].copy_from_slice(&src_mac);
frame[12..14].copy_from_slice(&0x0800u16.to_be_bytes());
let ip_start = ETH_HEADER_LEN;
let ip = &mut frame[ip_start..ip_start + 20];
ip[0] = 0x45; ip[2..4].copy_from_slice(&(ip_total_len as u16).to_be_bytes());
ip[8] = 64; ip[9] = 1; ip[12..16].copy_from_slice(&src_ip.octets());
ip[16..20].copy_from_slice(&dst_ip.octets());
let ip_cksum = ipv4_checksum(ip);
ip[10..12].copy_from_slice(&ip_cksum.to_be_bytes());
let icmp_start = ETH_HEADER_LEN + 20;
frame[icmp_start..].copy_from_slice(icmp_payload);
if icmp_payload.len() >= 4 {
frame[icmp_start + 2] = 0;
frame[icmp_start + 3] = 0;
let icmp_cksum = internet_checksum(&frame[icmp_start..]);
frame[icmp_start + 2..icmp_start + 4].copy_from_slice(&icmp_cksum.to_be_bytes());
}
frame
}
fn internet_checksum(data: &[u8]) -> u16 {
let mut sum: u32 = 0;
let mut i = 0;
while i + 1 < data.len() {
sum += u32::from(u16::from_be_bytes([data[i], data[i + 1]]));
i += 2;
}
if i < data.len() {
sum += u32::from(data[i]) << 8;
}
while sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16);
}
!sum as u16
}
fn ipv4_checksum(header: &[u8]) -> u16 {
let mut sum: u32 = 0;
let mut i = 0;
while i + 1 < header.len() {
if i != 10 {
sum += u32::from(u16::from_be_bytes([header[i], header[i + 1]]));
}
i += 2;
}
while sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16);
}
!sum as u16
}
pub struct SocketProxy {
icmp: IcmpProxy,
udp: UdpProxy,
inbound: super::inbound_relay::InboundRelay,
reply_tx: mpsc::Sender<Vec<u8>>,
}
impl SocketProxy {
#[must_use]
pub fn new(
gateway_ip: Ipv4Addr,
gateway_mac: [u8; 6],
guest_ip: Ipv4Addr,
reply_tx: mpsc::Sender<Vec<u8>>,
) -> Self {
let inbound = super::inbound_relay::InboundRelay::new(
reply_tx.clone(),
gateway_mac,
gateway_ip,
guest_ip,
);
Self {
icmp: IcmpProxy::new(reply_tx.clone(), gateway_mac),
udp: UdpProxy::new(reply_tx.clone(), gateway_mac),
inbound,
reply_tx,
}
}
#[must_use]
pub fn reply_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.reply_tx.clone()
}
pub fn handle_outbound(&mut self, frame: &[u8], guest_mac: [u8; 6]) {
if frame.len() < ETH_HEADER_LEN + 20 {
return;
}
if self.inbound.try_handle_reply(frame, guest_mac) {
return;
}
let protocol = frame[ETH_HEADER_LEN + 9];
match protocol {
1 => self.icmp.proxy_icmp(frame, guest_mac),
17 => self.udp.proxy_udp(frame, guest_mac),
_ => {
tracing::trace!("Socket proxy: dropping protocol {}", protocol);
}
}
}
pub(crate) fn handle_inbound_command(
&mut self,
cmd: super::inbound_relay::InboundCommand,
guest_mac: [u8; 6],
) {
use super::inbound_relay::InboundCommand;
if let InboundCommand::UdpReceived {
host_port,
data,
reply_tx,
..
} = cmd
{
self.inbound
.inject_udp(host_port, &data, reply_tx, guest_mac);
}
}
pub fn maintenance(&mut self) {
self.udp.cleanup_stale_flows();
self.inbound.cleanup();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_socket_proxy_creation() {
let (tx, _rx) = mpsc::channel(16);
let gw_ip = Ipv4Addr::new(192, 168, 64, 1);
let gw_mac = [0x02, 0xAB, 0xCD, 0x00, 0x00, 0x01];
let guest_ip = Ipv4Addr::new(192, 168, 64, 2);
let _proxy = SocketProxy::new(gw_ip, gw_mac, guest_ip, tx);
}
#[test]
fn test_socket_proxy_drops_unknown_protocol() {
let (tx, _rx) = mpsc::channel(16);
let gw_ip = Ipv4Addr::new(192, 168, 64, 1);
let gw_mac = [0x02, 0xAB, 0xCD, 0x00, 0x00, 0x01];
let guest_ip = Ipv4Addr::new(192, 168, 64, 2);
let mut proxy = SocketProxy::new(gw_ip, gw_mac, guest_ip, tx);
let mut frame = vec![0u8; ETH_HEADER_LEN + 20];
frame[12..14].copy_from_slice(&0x0800u16.to_be_bytes());
frame[ETH_HEADER_LEN] = 0x45; frame[ETH_HEADER_LEN + 9] = 50;
let guest_mac = [0x02, 0x00, 0x00, 0x00, 0x00, 0x99];
proxy.handle_outbound(&frame, guest_mac);
}
#[test]
fn test_socket_proxy_ignores_short_frames() {
let (tx, _rx) = mpsc::channel(16);
let gw_ip = Ipv4Addr::new(192, 168, 64, 1);
let gw_mac = [0x02, 0xAB, 0xCD, 0x00, 0x00, 0x01];
let guest_ip = Ipv4Addr::new(192, 168, 64, 2);
let mut proxy = SocketProxy::new(gw_ip, gw_mac, guest_ip, tx);
let guest_mac = [0x02, 0x00, 0x00, 0x00, 0x00, 0x99];
proxy.handle_outbound(&[0u8; 10], guest_mac);
}
#[test]
fn test_udp_proxy_cleanup_stale_flows() {
let (tx, _rx) = mpsc::channel(16);
let gw_ip = Ipv4Addr::new(192, 168, 64, 1);
let gw_mac = [0x02, 0xAB, 0xCD, 0x00, 0x00, 0x01];
let mut proxy = UdpProxy::new(tx, gw_mac);
let key = (
Ipv4Addr::new(192, 168, 64, 2),
1234,
Ipv4Addr::new(8, 8, 8, 8),
53,
);
proxy.flows.insert(
key,
UdpFlow {
last_active: Instant::now() - std::time::Duration::from_secs(120),
payload_tx: mpsc::channel(1).0,
},
);
assert_eq!(proxy.flows.len(), 1);
proxy.cleanup_stale_flows();
assert_eq!(proxy.flows.len(), 0, "Stale flow should be cleaned up");
}
#[tokio::test]
async fn test_reply_channel_flow() {
let (tx, mut rx) = mpsc::channel(16);
let gw_ip = Ipv4Addr::new(192, 168, 64, 1);
let gw_mac = [0x02, 0xAB, 0xCD, 0x00, 0x00, 0x01];
let test_frame = vec![0xDE, 0xAD, 0xBE, 0xEF];
tx.send(test_frame.clone()).await.unwrap();
let received = rx.recv().await.unwrap();
assert_eq!(received, test_frame);
let guest_ip = Ipv4Addr::new(192, 168, 64, 2);
let _proxy = SocketProxy::new(gw_ip, gw_mac, guest_ip, tx);
}
}