//! anydoor 0.0.1
//!
//! A tool for forwarding traffic to a remote server.
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{error, info};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::storage::FileStorage;

pub struct Forwarder {
    local_addr: SocketAddr,
    remote_addr: SocketAddr,
}

impl Forwarder {
    pub fn new(local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
        Self {
            local_addr,
            remote_addr,
        }
    }

    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error>> {
        let listener = TcpListener::bind(self.local_addr).await?;
        info!("Started forwarding from {} to {}", self.local_addr, self.remote_addr);

        loop {
            match listener.accept().await {
                Ok((inbound, _)) => {
                    let remote_addr = self.remote_addr;
                    tokio::spawn(async move {
                        match handle_connection(inbound, remote_addr).await {
                            Ok(_) => info!("Connection handled successfully"),
                            Err(e) => error!("Connection error: {}", e),
                        }
                    });
                }
                Err(e) => error!("Failed to accept connection: {}", e),
            }
        }
    }
}

/// Relays bytes in both directions between an accepted client connection and
/// a newly opened connection to `remote_addr`.
///
/// The two directions are copied concurrently and independently. When one
/// side reaches EOF, the write half of the opposite socket is shut down so the
/// peer observes the half-close. The other direction keeps running until it
/// also reaches EOF. This keeps responses intact for clients that half-close
/// after sending their request.
///
/// # Errors
///
/// Returns an error if the outbound connection cannot be established, or if
/// either direction fails with an I/O error. In that case the other direction
/// is dropped, and the error message says which direction failed.
async fn handle_connection(
    mut inbound: TcpStream,
    remote_addr: SocketAddr,
) -> Result<(), Box<dyn std::error::Error>> {
    let mut outbound = TcpStream::connect(remote_addr).await?;

    let (client_read, client_write) = inbound.split();
    let (server_read, server_write) = outbound.split();

    let client_to_server = async {
        pump_until_eof(client_read, server_write)
            .await
            .map_err(|e| std::io::Error::new(e.kind(), format!("Client to server error: {}", e)))
    };
    let server_to_client = async {
        pump_until_eof(server_read, client_write)
            .await
            .map_err(|e| std::io::Error::new(e.kind(), format!("Server to client error: {}", e)))
    };

    // Finish only when both directions are done. The first I/O error cancels
    // the other direction and is reported to the caller.
    tokio::try_join!(client_to_server, server_to_client)?;
    Ok(())
}

/// Copies bytes from `reader` to `writer` until `reader` reports EOF. It then
/// shuts down `writer` so the far end sees the end of the stream.
async fn pump_until_eof<R, W>(mut reader: R, mut writer: W) -> std::io::Result<()>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    let mut buffer = [0u8; 8 * 1024];
    loop {
        let n = reader.read(&mut buffer).await?;
        if n == 0 {
            break;
        }
        writer.write_all(&buffer[..n]).await?;
    }
    // Propagate the half-close instead of leaving the peer waiting for more data.
    writer.shutdown().await
}

/// One persisted forwarding rule. Traffic arriving on `local_port` is relayed
/// to `remote_host:remote_port`.
///
/// Field order is preserved on purpose: it is the serialized field order.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ForwardRule {
    /// Identifier of the rule; also the key of its running forwarder task.
    pub id: String,
    /// Local port to accept connections on.
    pub local_port: u16,
    /// Remote endpoint host: either an IP literal or a hostname to resolve.
    pub remote_host: String,
    /// Remote endpoint port.
    pub remote_port: u16,
    /// Transport protocol of the rule.
    pub protocol: Protocol,
    /// Whether a forwarder should be running for this rule.
    pub enabled: bool,
}

/// Transport protocol of a [`ForwardRule`].
///
/// Variant names are part of the public API and of the serialized form, so
/// they keep their upper-case spelling.
#[derive(Clone, Debug)]
#[derive(Deserialize, Serialize)]
pub enum Protocol {
    /// Stream forwarding over TCP.
    TCP,
    /// Datagram forwarding over UDP.
    /// NOTE(review): `Forwarder` in this file only implements TCP relaying.
    UDP,
}

pub struct ForwardManager {
    storage: FileStorage,
    forwarders: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
}

impl ForwardManager {
    pub fn new(storage: FileStorage) -> Self {
        Self {
            storage,
            forwarders: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    pub async fn add_rule(&self, rule: ForwardRule) -> Result<(), Box<dyn std::error::Error>> {
        let id = rule.id.clone();
        
        if rule.enabled {
            self.start_forwarder(&rule).await?;
        }
        
        self.storage.save_rule(rule).await?;
        Ok(())
    }

    pub async fn remove_rule(&self, id: &str) -> Result<(), Box<dyn std::error::Error>> {
        let mut forwarders = self.forwarders.write().await;
        if let Some(handle) = forwarders.remove(id) {
            handle.abort();
        }
        
        self.storage.remove_rule(id).await?;
        Ok(())
    }

    async fn start_forwarder(&self, rule: &ForwardRule) -> Result<(), Box<dyn std::error::Error>> {
        let local_addr = SocketAddr::from(([0, 0, 0, 0], rule.local_port));
        
        let remote_addr = match (rule.remote_host.as_str(), rule.remote_port) {
            (host, port) if host.parse::<std::net::IpAddr>().is_ok() => {
                format!("{}:{}", host, port)
                    .parse()
                    .map_err(|e| format!("Invalid IP address format: {}", e))?
            },
            (host, port) => {
                use tokio::net::lookup_host;
                let mut addrs = lookup_host(format!("{}:{}", host, port)).await
                    .map_err(|e| format!("Failed to resolve hostname: {}", e))?;
                
                addrs.next()
                    .ok_or_else(|| "No valid address found for hostname".to_string())?
            }
        };

        info!("Starting forwarder from {} to {}", local_addr, remote_addr);
        
        let forwarder = Forwarder::new(local_addr, remote_addr);
        let id = rule.id.clone();
        
        let mut forwarders = self.forwarders.write().await;
        let handle = tokio::spawn(async move {
            if let Err(e) = forwarder.start().await {
                error!("Forwarder error: {}", e);
            }
        });
        
        forwarders.insert(id, handle);
        Ok(())
    }

    pub async fn get_rules(&self) -> Vec<ForwardRule> {
        self.storage.get_rules().await
    }

    pub async fn restore_rules(&self) -> Result<(), Box<dyn std::error::Error>> {
        for rule in self.get_rules().await {
            if rule.enabled {
                self.start_forwarder(&rule).await?;
            }
        }
        Ok(())
    }
}