use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{error, info};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::storage::FileStorage;
/// A TCP port forwarder: accepts connections on `local_addr` and relays each
/// one to `remote_addr`.
pub struct Forwarder {
    /// Address the listening socket is bound to.
    local_addr: SocketAddr,
    /// Address every accepted connection is relayed to.
    remote_addr: SocketAddr,
}
impl Forwarder {
pub fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
Self {
local_addr,
remote_addr,
}
}
pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
let listener = TcpListener::bind(self.local_addr).await?;
info!("Started forwarding from {} to {}", self.local_addr, self.remote_addr);
loop {
match listener.accept().await {
Ok((inbound, _)) => {
let remote_addr = self.remote_addr;
tokio::spawn(async move {
match handle_connection(inbound, remote_addr).await {
Ok(_) => info!("Connection handled successfully"),
Err(e) => error!("Connection error: {}", e),
}
});
}
Err(e) => error!("Failed to accept connection: {}", e),
}
}
}
}
/// Size of the buffer used by each relay direction.
const RELAY_BUFFER_SIZE: usize = 8 * 1024;

/// Relays traffic between `inbound` and a new TCP connection to `remote_addr`.
///
/// Both directions run concurrently until each has reached end-of-stream, or
/// until one of them fails. When one side stops sending, the matching write
/// half of the other connection is shut down so that peer sees EOF, while the
/// opposite direction keeps flowing. A half-close therefore no longer
/// truncates data that is still in flight the other way.
///
/// # Errors
///
/// Returns an error if connecting to `remote_addr` fails, or if reading,
/// writing or shutting down either direction fails. Relay errors are prefixed
/// with the direction that failed.
async fn handle_connection(
    mut inbound: TcpStream,
    remote_addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut outbound = TcpStream::connect(remote_addr).await?;
    let (ri, wi) = inbound.split();
    let (ro, wo) = outbound.split();
    let client_to_server = async {
        relay_one_way(ri, wo)
            .await
            .map_err(|e| format!("Client to server error: {}", e))
    };
    let server_to_client = async {
        relay_one_way(ro, wi)
            .await
            .map_err(|e| format!("Server to client error: {}", e))
    };
    // Unlike `select!`, this waits for BOTH directions and only stops early
    // when one of them returns an error.
    tokio::try_join!(client_to_server, server_to_client)?;
    Ok(())
}

/// Copies everything from `reader` to `writer` until `reader` reports EOF.
/// It then shuts `writer` down so the receiving peer observes end-of-stream.
async fn relay_one_way<R, W>(mut reader: R, mut writer: W) -> std::io::Result<()>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut buffer = vec![0u8; RELAY_BUFFER_SIZE];
    loop {
        let n = reader.read(&mut buffer).await?;
        if n == 0 {
            writer.shutdown().await?;
            return Ok(());
        }
        writer.write_all(&buffer[..n]).await?;
    }
}
/// A persisted port-forwarding rule.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ForwardRule {
    /// Identifier of the rule; used as the key for its running forwarder and
    /// when removing it from storage.
    pub id: String,
    /// Local port the forwarder listens on (bound on `0.0.0.0`).
    pub local_port: u16,
    /// Destination host: either an IP literal or a hostname resolved when the
    /// forwarder starts.
    pub remote_host: String,
    /// Destination port.
    pub remote_port: u16,
    /// Transport protocol of the rule.
    pub protocol: Protocol,
    /// Whether a forwarder should be running for this rule.
    pub enabled: bool,
}
/// Transport protocol of a [`ForwardRule`].
///
/// NOTE(review): only a TCP forwarder is implemented in this module.
// The variant names double as the serde representation, so they keep their
// upper-case spelling instead of the `Tcp`/`Udp` clippy would prefer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum Protocol {
    /// Transmission Control Protocol.
    TCP,
    /// User Datagram Protocol.
    UDP,
}
/// Shared table of running forwarder tasks, keyed by rule id.
type ForwarderTasks = Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>;

/// Owns the persisted forwarding rules and the tasks serving the enabled ones.
pub struct ForwardManager {
    /// Persistent store of [`ForwardRule`]s.
    storage: FileStorage,
    /// Running forwarder tasks keyed by [`ForwardRule::id`].
    forwarders: ForwarderTasks,
}
impl ForwardManager {
pub fn new(storage: FileStorage) -> Self {
Self {
storage,
forwarders: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn add_rule(&self, rule: ForwardRule) -> Result<(), Box<dyn std::error::Error>> {
let id = rule.id.clone();
if rule.enabled {
self.start_forwarder(&rule).await?;
}
self.storage.save_rule(rule).await?;
Ok(())
}
pub async fn remove_rule(&self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
let mut forwarders = self.forwarders.write().await;
if let Some(handle) = forwarders.remove(id) {
handle.abort();
}
self.storage.remove_rule(id).await?;
Ok(())
}
async fn start_forwarder(&self, rule: &ForwardRule) -> Result<(), Box<dyn std::error::Error>> {
let local_addr = SocketAddr::from(([0, 0, 0, 0], rule.local_port));
let remote_addr = match (rule.remote_host.as_str(), rule.remote_port) {
(host, port) if host.parse::<std::net::IpAddr>().is_ok() => {
format!("{}:{}", host, port)
.parse()
.map_err(|e| format!("Invalid IP address format: {}", e))?
},
(host, port) => {
use tokio::net::lookup_host;
let mut addrs = lookup_host(format!("{}:{}", host, port)).await
.map_err(|e| format!("Failed to resolve hostname: {}", e))?;
addrs.next()
.ok_or_else(|| "No valid address found for hostname".to_string())?
}
};
info!("Starting forwarder from {} to {}", local_addr, remote_addr);
let forwarder = Forwarder::new(local_addr, remote_addr);
let id = rule.id.clone();
let mut forwarders = self.forwarders.write().await;
let handle = tokio::spawn(async move {
if let Err(e) = forwarder.start().await {
error!("Forwarder error: {}", e);
}
});
forwarders.insert(id, handle);
Ok(())
}
pub async fn get_rules(&self) -> Vec<ForwardRule> {
self.storage.get_rules().await
}
pub async fn restore_rules(&self) -> Result<(), Box<dyn std::error::Error>> {
for rule in self.get_rules().await {
if rule.enabled {
self.start_forwarder(&rule).await?;
}
}
Ok(())
}
}