use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{RwLock, mpsc};
use tokio::task::JoinHandle;
use crate::error::{NetError, Result};
/// One forwarding rule: traffic arriving at `host_addr` is relayed to
/// `guest_addr` over `protocol`.
#[derive(Debug, Clone)]
pub struct PortForwardRule {
    /// Address bound on the host side.
    pub host_addr: SocketAddr,
    /// Destination that received traffic is relayed to.
    pub guest_addr: SocketAddr,
    /// Transport protocol this rule forwards.
    pub protocol: Protocol,
}
impl PortForwardRule {
#[must_use]
pub fn tcp(host_addr: SocketAddr, guest_addr: SocketAddr) -> Self {
Self {
host_addr,
guest_addr,
protocol: Protocol::Tcp,
}
}
#[must_use]
pub fn udp(host_addr: SocketAddr, guest_addr: SocketAddr) -> Self {
Self {
host_addr,
guest_addr,
protocol: Protocol::Udp,
}
}
}
/// Transport protocol a [`PortForwardRule`] applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// Stream forwarding: each accepted connection is relayed to the guest.
    Tcp,
    /// Datagram forwarding.
    Udp,
}
/// Bookkeeping for one spawned forwarding task.
struct ActiveForwarder {
    /// Sender half of the channel the task watches for a shutdown request.
    shutdown_tx: mpsc::Sender<()>,
    /// Handle of the spawned task, used to abort it.
    handle: JoinHandle<()>,
}
/// Owns a set of forwarding rules and the tasks currently serving them.
pub struct PortForwarder {
    /// Configured rules, in insertion order.
    rules: Vec<PortForwardRule>,
    /// Running forwarders, keyed by the string form of their host address.
    active: Arc<RwLock<HashMap<String, ActiveForwarder>>>,
    /// Whether `start` succeeded and `stop` has not been called since.
    running: bool,
}
impl PortForwarder {
#[must_use]
pub fn new() -> Self {
Self {
rules: Vec::new(),
active: Arc::new(RwLock::new(HashMap::new())),
running: false,
}
}
pub fn add_rule(&mut self, rule: PortForwardRule) {
self.rules.push(rule);
}
pub fn remove_rule(&mut self, host_addr: SocketAddr) {
self.rules.retain(|r| r.host_addr != host_addr);
}
#[must_use]
pub fn rules(&self) -> &[PortForwardRule] {
&self.rules
}
pub async fn start(&mut self) -> Result<()> {
if self.running {
return Ok(());
}
for rule in &self.rules {
self.start_forwarder(rule.clone()).await?;
}
self.running = true;
tracing::info!("Port forwarder started with {} rules", self.rules.len());
Ok(())
}
pub async fn stop(&mut self) {
if !self.running {
return;
}
let mut active = self.active.write().await;
for (addr, forwarder) in active.drain() {
let _ = forwarder.shutdown_tx.send(()).await;
forwarder.handle.abort();
tracing::debug!("Stopped forwarder for {}", addr);
}
self.running = false;
tracing::info!("Port forwarder stopped");
}
async fn start_forwarder(&self, rule: PortForwardRule) -> Result<()> {
let key = rule.host_addr.to_string();
{
let active = self.active.read().await;
if active.contains_key(&key) {
return Ok(());
}
}
let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
let active = Arc::clone(&self.active);
let handle = match rule.protocol {
Protocol::Tcp => {
let listener = TcpListener::bind(rule.host_addr).await.map_err(|e| {
NetError::io(std::io::Error::new(
e.kind(),
format!("failed to bind {}: {}", rule.host_addr, e),
))
})?;
tracing::info!(
"TCP port forward: {} -> {}",
rule.host_addr,
rule.guest_addr
);
tokio::spawn(tcp_forward_loop(listener, rule.guest_addr, shutdown_rx))
}
Protocol::Udp => {
let socket = UdpSocket::bind(rule.host_addr).await.map_err(|e| {
NetError::io(std::io::Error::new(
e.kind(),
format!("failed to bind {}: {}", rule.host_addr, e),
))
})?;
tracing::info!(
"UDP port forward: {} -> {}",
rule.host_addr,
rule.guest_addr
);
tokio::spawn(udp_forward_loop(socket, rule.guest_addr, shutdown_rx))
}
};
let mut active_guard = active.write().await;
active_guard.insert(
key,
ActiveForwarder {
shutdown_tx,
handle,
},
);
Ok(())
}
pub async fn add_and_start(&mut self, rule: PortForwardRule) -> Result<()> {
if self.running {
self.start_forwarder(rule.clone()).await?;
}
self.rules.push(rule);
Ok(())
}
pub async fn remove_and_stop(&mut self, host_addr: SocketAddr) {
let key = host_addr.to_string();
let mut active = self.active.write().await;
if let Some(forwarder) = active.remove(&key) {
let _ = forwarder.shutdown_tx.send(()).await;
forwarder.handle.abort();
}
self.rules.retain(|r| r.host_addr != host_addr);
}
pub async fn active_count(&self) -> usize {
self.active.read().await.len()
}
#[must_use]
pub fn is_running(&self) -> bool {
self.running
}
}
impl Default for PortForwarder {
fn default() -> Self {
Self::new()
}
}
async fn tcp_forward_loop(
listener: TcpListener,
guest_addr: SocketAddr,
mut shutdown_rx: mpsc::Receiver<()>,
) {
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
tracing::debug!("TCP forwarder shutdown");
break;
}
result = listener.accept() => {
match result {
Ok((client, peer_addr)) => {
tracing::debug!("TCP connection from {} -> {}", peer_addr, guest_addr);
tokio::spawn(handle_tcp_connection(client, guest_addr));
}
Err(e) => {
tracing::warn!("TCP accept error: {}", e);
}
}
}
}
}
}
async fn handle_tcp_connection(mut client: TcpStream, guest_addr: SocketAddr) {
let mut guest = match TcpStream::connect(guest_addr).await {
Ok(stream) => stream,
Err(e) => {
tracing::warn!("Failed to connect to guest {}: {}", guest_addr, e);
return;
}
};
let (mut client_read, mut client_write) = client.split();
let (mut guest_read, mut guest_write) = guest.split();
let client_to_guest = async {
let mut buf = vec![0u8; 8192];
loop {
let n = match client_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => n,
Err(_) => break,
};
if guest_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
};
let guest_to_client = async {
let mut buf = vec![0u8; 8192];
loop {
let n = match guest_read.read(&mut buf).await {
Ok(0) => break,
Ok(n) => n,
Err(_) => break,
};
if client_write.write_all(&buf[..n]).await.is_err() {
break;
}
}
};
tokio::select! {
() = client_to_guest => {}
() = guest_to_client => {}
}
}
/// Relays datagrams between clients of the host socket and `guest_addr`.
///
/// The loop ends when `shutdown_rx` yields, which happens on a shutdown
/// message or when every sender has been dropped.
///
/// Routing:
/// - A datagram from any peer other than `guest_addr` is sent to the guest,
///   and that peer becomes the reply target.
/// - A datagram from `guest_addr` is sent to the current reply target. It is
///   dropped if no client has been seen yet.
///
/// Limitation: only a single reply target is tracked. With several concurrent
/// clients, all guest replies go to whichever client sent most recently.
async fn udp_forward_loop(
    socket: UdpSocket,
    guest_addr: SocketAddr,
    mut shutdown_rx: mpsc::Receiver<()>,
) {
    // Only this task touches the socket and the reply target, so plain locals
    // are enough; no Arc or lock is needed.
    let mut reply_to: Option<SocketAddr> = None;
    let mut buf = vec![0u8; 65535];
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::debug!("UDP forwarder shutdown");
                break;
            }
            result = socket.recv_from(&mut buf) => {
                let (len, peer_addr) = match result {
                    Ok(received) => received,
                    Err(e) => {
                        tracing::warn!("UDP recv error: {}", e);
                        continue;
                    }
                };
                let target = if peer_addr == guest_addr {
                    match reply_to {
                        Some(addr) => addr,
                        // No client yet: nobody to deliver the reply to.
                        None => continue,
                    }
                } else {
                    reply_to = Some(peer_addr);
                    guest_addr
                };
                // UDP is best-effort, so a failed send is logged, not fatal.
                if let Err(e) = socket.send_to(&buf[..len], target).await {
                    tracing::debug!("UDP send to {} failed: {}", target, e);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Loopback address on `port`, used as the host side of test rules.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
    }

    /// Guest-network address (192.168.64.2) on `port`.
    fn guest_at(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 64, 2), port))
    }

    #[test]
    fn test_port_forward_rule_tcp() {
        let (host, guest) = (loopback(8080), guest_at(80));
        let rule = PortForwardRule::tcp(host, guest);
        assert_eq!(rule.host_addr, host);
        assert_eq!(rule.guest_addr, guest);
        assert_eq!(rule.protocol, Protocol::Tcp);
    }

    #[test]
    fn test_port_forward_rule_udp() {
        let rule = PortForwardRule::udp(loopback(5353), guest_at(53));
        assert_eq!(rule.protocol, Protocol::Udp);
    }

    #[test]
    fn test_port_forwarder_add_remove() {
        let host = loopback(8080);
        let mut forwarder = PortForwarder::new();
        forwarder.add_rule(PortForwardRule::tcp(host, guest_at(80)));
        assert_eq!(forwarder.rules().len(), 1);
        forwarder.remove_rule(host);
        assert!(forwarder.rules().is_empty());
    }

    #[tokio::test]
    async fn test_port_forwarder_not_running() {
        let forwarder = PortForwarder::new();
        assert!(!forwarder.is_running());
        assert_eq!(forwarder.active_count().await, 0);
    }
}