rustgate-proxy 0.2.0

MITM-capable HTTP/HTTPS proxy with WebSocket C2 tunneling (SOCKS5, reverse TCP)
Documentation
use clap::{Parser, Subcommand};
use rustgate::cert::CertificateAuthority;
use rustgate::handler::LoggingHandler;
use rustgate::proxy::ProxyState;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{info, warn};

/// Legal-use notice that `main` prints to stderr once the CLI has been parsed.
///
/// Each line is newline-terminated; together with `eprintln!`'s own newline
/// this leaves a blank line between the notice and subsequent output.
const BANNER: &str = concat!(
    "WARNING: This tool is for authorized security research only.\n",
    "Unauthorized use may violate applicable laws. Use responsibly.\n"
);

// Top-level command line. When no subcommand is given, `main` runs the
// HTTP/HTTPS proxy using `--host`, `--port` and `--mitm`.
//
// NOTE: the `///` lines on the fields are clap help text (shown by `--help`),
// so maintainer notes here use `//` to avoid changing the CLI output.
#[derive(Parser, Debug)]
#[command(name = "rustgate", about = "MITM proxy and C2 tunnel toolkit")]
struct Cli {
    // Optional mode selector; `None` means "run the proxy".
    #[command(subcommand)]
    command: Option<Commands>,

    /// Address to listen on (proxy mode)
    #[arg(long = "host", default_value = "127.0.0.1")]
    host: String,

    /// Port to listen on (proxy mode)
    #[arg(short = 'p', long = "port", default_value_t = 8080)]
    port: u16,

    /// Enable MITM mode (TLS interception)
    #[arg(long = "mitm")]
    mitm: bool,
}

// Subcommands selecting a non-proxy mode of operation.
//
// NOTE: `///` lines on variants and fields are clap help text shown by
// `--help`; maintainer-only notes use `//`.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Run as C2 server (accept WebSocket clients via mTLS)
    Server {
        /// Address to listen on (server mode)
        #[arg(long, default_value = "0.0.0.0")]
        host: String,
        /// Port to listen on (server mode)
        #[arg(short, long, default_value_t = 4443)]
        port: u16,
        /// Hostname/IP for the server certificate (clients connect to this name)
        #[arg(long)]
        server_name: String,
        /// Path to CA directory (required — each deployment should use its own CA)
        #[arg(long)]
        ca_dir: PathBuf,
    },
    /// Run as C2 client (connect to server via mTLS)
    Client {
        /// Server WebSocket URL (e.g. wss://server.example.com:4443)
        #[arg(long)]
        server_url: String,
        /// Path to client certificate PEM
        #[arg(long)]
        cert: PathBuf,
        /// Path to client private key PEM
        #[arg(long)]
        key: PathBuf,
        /// Path to CA cert PEM for verifying server
        #[arg(long)]
        ca_cert: PathBuf,
    },
    /// Generate a client certificate signed by the CA
    GenClientCert {
        /// Common name for the client certificate
        #[arg(long, default_value = "rustgate-client")]
        cn: String,
        /// Output directory for cert and key PEM files
        #[arg(long, default_value = ".")]
        out_dir: PathBuf,
        /// Path to CA directory (required — must match the server's CA)
        #[arg(long)]
        ca_dir: PathBuf,
    },
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::from_default_env()
                .add_directive("rustgate=info".parse().unwrap()),
        )
        .init();

    let cli = Cli::parse();

    eprintln!("{BANNER}");

    match cli.command {
        None => run_proxy(cli.host, cli.port, cli.mitm).await,
        Some(Commands::Server {
            host,
            port,
            server_name,
            ca_dir,
        }) => run_server(host, port, server_name, ca_dir).await,
        Some(Commands::Client {
            server_url,
            cert,
            key,
            ca_cert,
        }) => run_client(server_url, cert, key, ca_cert).await,
        Some(Commands::GenClientCert {
            cn,
            out_dir,
            ca_dir,
        }) => run_gen_client_cert(cn, out_dir, ca_dir).await,
    }
}

/// Runs the HTTP/HTTPS proxy on `host:port` until a fatal error occurs.
///
/// Constructs the default `CertificateAuthority`, shares it with a
/// `LoggingHandler` through `ProxyState`, and spawns one task per accepted
/// connection running `rustgate::proxy::handle_connection`. When `mitm` is
/// set, the CA certificate path is logged so it can be installed in clients.
///
/// # Errors
///
/// Returns an error if the CA cannot be constructed, if its path cannot be
/// determined (MITM mode), if binding fails, or if `accept()` fails with an
/// error that is not specific to a single pending connection.
async fn run_proxy(
    host: String,
    port: u16,
    mitm: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let listen_addr = format!("{host}:{port}");

    let ca = Arc::new(CertificateAuthority::new().await?);

    if mitm {
        let ca_path = CertificateAuthority::ca_cert_path()?;
        info!(
            "MITM mode enabled. Install CA cert: {}",
            ca_path.display()
        );
    }

    let state = Arc::new(ProxyState {
        ca,
        mitm,
        handler: Arc::new(LoggingHandler),
    });

    let listener = TcpListener::bind(&listen_addr).await?;
    let local_addr = listener.local_addr()?;

    // Inspect the address actually bound rather than parsing the raw --host
    // string: hostnames (which never parse as an IpAddr) that resolve to a
    // routable interface must trigger the exposure warning too.
    if !local_addr.ip().is_loopback() {
        warn!(
            "Binding to non-loopback address ({host}). \
             No authentication is configured — this proxy may be accessible from the network."
        );
    }

    // Report the bound address so e.g. `--port 0` shows the real port.
    info!("RustGate proxy listening on {local_addr}");

    loop {
        let (stream, addr) = match listener.accept().await {
            Ok(conn) => conn,
            // A client that aborts/resets during the handshake must not take
            // the whole proxy down; skip that connection and keep serving.
            Err(e) if is_transient_accept_error(&e) => {
                warn!("Failed to accept connection: {e}");
                continue;
            }
            Err(e) => return Err(e.into()),
        };
        let state = Arc::clone(&state);
        tokio::spawn(async move {
            rustgate::proxy::handle_connection(stream, addr, state).await;
        });
    }
}

/// Returns `true` for `accept()` errors that affect only the pending connection.
fn is_transient_accept_error(err: &std::io::Error) -> bool {
    matches!(
        err.kind(),
        std::io::ErrorKind::ConnectionAborted
            | std::io::ErrorKind::ConnectionReset
            | std::io::ErrorKind::Interrupted
    )
}

/// Runs the C2 server on `host:port`, presenting a certificate for
/// `server_name` issued from the CA loaded out of `ca_dir`.
///
/// # Errors
///
/// Propagates failures from loading the CA (`CertificateAuthority::with_dir`)
/// and from `rustgate::c2::server::run`.
async fn run_server(
    host: String,
    port: u16,
    server_name: String,
    ca_dir: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
    let authority = CertificateAuthority::with_dir(ca_dir).await?;
    let shared_ca = Arc::new(authority);
    rustgate::c2::server::run(host.as_str(), port, server_name.as_str(), shared_ca).await?;
    Ok(())
}

/// Runs the C2 client against `server_url`, using the client certificate and
/// key PEM files and the CA certificate PEM given on the command line.
///
/// # Errors
///
/// Returns an error if any of the three paths is not valid UTF-8 (the client
/// API takes `&str` paths), or propagates the error from
/// `rustgate::c2::client::run`.
async fn run_client(
    server_url: String,
    cert: PathBuf,
    key: PathBuf,
    ca_cert: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
    /// Borrows `path` as `&str`, or fails with a message naming the argument.
    fn utf8_path<'a>(
        path: &'a std::path::Path,
        what: &str,
    ) -> Result<&'a str, Box<dyn std::error::Error>> {
        // Previously a non-UTF-8 path was silently replaced with "", which
        // surfaced later as a confusing "file not found"-style error.
        path.to_str()
            .ok_or_else(|| format!("{what} path is not valid UTF-8: {}", path.display()).into())
    }

    let cert = utf8_path(&cert, "--cert")?;
    let key = utf8_path(&key, "--key")?;
    let ca_cert = utf8_path(&ca_cert, "--ca-cert")?;

    rustgate::c2::client::run(&server_url, cert, key, ca_cert).await?;
    Ok(())
}

/// Generates a client certificate signed by the CA in `ca_dir` and writes it
/// to `out_dir` as `<cn>.pem` (certificate) and `<cn>-key.pem` (private key).
///
/// On Unix the key file is created with mode 0600, or forced to 0600 before
/// being overwritten if it already exists.
///
/// # Errors
///
/// Returns an error if `cn` is empty, contains `/`, `\` or NUL, or starts with
/// `.`; if the CA cannot be loaded or fails to issue the certificate; if (on
/// Unix) either output path is an existing symlink; or on any I/O failure.
async fn run_gen_client_cert(
    cn: String,
    out_dir: PathBuf,
    ca_dir: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
    // The CN becomes part of the output filenames. Validate it before loading
    // the CA or touching the filesystem: reject empty names (which would give
    // ".pem"/"-key.pem") and anything that could escape `out_dir`.
    if cn.is_empty() {
        return Err("CN must not be empty".into());
    }
    if cn.contains(['/', '\\', '\0']) || cn.starts_with('.') {
        return Err("CN must not contain path separators or start with '.'".into());
    }

    let ca = CertificateAuthority::with_dir(ca_dir).await?;
    let (cert_pem, key_pem) = ca.generate_client_cert(&cn)?;

    tokio::fs::create_dir_all(&out_dir).await?;
    let cert_path = out_dir.join(format!("{cn}.pem"));
    let key_path = out_dir.join(format!("{cn}-key.pem"));

    // Reject symlinks to prevent arbitrary file clobber. This is a best-effort
    // check-then-write; it does not close the race window entirely.
    #[cfg(unix)]
    {
        for path in [&cert_path, &key_path] {
            if let Ok(meta) = tokio::fs::symlink_metadata(path).await {
                if meta.file_type().is_symlink() {
                    return Err(format!("Refusing to overwrite symlink: {}", path.display()).into());
                }
            }
        }
    }

    tokio::fs::write(&cert_path, &cert_pem).await?;

    // Write the private key with restricted permissions.
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        use tokio::io::AsyncWriteExt;

        // Prefer exclusive creation so the file is 0600 from the moment it exists.
        let created = tokio::fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .mode(0o600)
            .open(&key_path)
            .await;
        match created {
            Ok(f) => {
                let mut writer = tokio::io::BufWriter::new(f);
                writer.write_all(key_pem.as_bytes()).await?;
                // tokio files may have in-flight writes; flush before drop.
                writer.flush().await?;
            }
            Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
                // Tighten permissions *before* writing so the new key is never
                // readable through the old file's (possibly lax) mode.
                tokio::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600))
                    .await?;
                tokio::fs::write(&key_path, &key_pem).await?;
            }
            // Anything else (permission denied, missing dir, ...) is a real
            // failure, not a reason to fall back to a plain overwrite.
            Err(e) => return Err(e.into()),
        }
    }
    #[cfg(not(unix))]
    {
        tokio::fs::write(&key_path, &key_pem).await?;
    }

    info!("Client certificate generated:");
    info!("  Cert: {}", cert_path.display());
    info!("  Key:  {}", key_path.display());
    info!("  CN:   {cn}");

    Ok(())
}