use clap::{Parser, Subcommand};
use rustgate::cert::CertificateAuthority;
use rustgate::handler::LoggingHandler;
use rustgate::proxy::ProxyState;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::net::TcpListener;
use tracing::{info, warn};
/// Legal notice written to stderr after argument parsing, before the selected
/// mode starts running.
const BANNER: &str = concat!(
    "WARNING: This tool is for authorized security research only.\n",
    "Unauthorized use may violate applicable laws. Use responsibly.\n"
);
// Top-level command line. With no subcommand, RustGate runs the local proxy
// configured by `--host`, `--port` and `--mitm`. Plain `//` comments are used on
// the struct itself because a `///` doc comment would also feed clap's
// about/long_about text, which is set explicitly below.
#[derive(Parser, Debug)]
#[command(name = "rustgate", about = "MITM proxy and C2 tunnel toolkit")]
struct Cli {
    // Selected mode; `None` means "run the proxy".
    #[command(subcommand)]
    command: Option<Commands>,
    /// Address the proxy listens on
    #[arg(long, default_value = "127.0.0.1")]
    host: String,
    /// Port the proxy listens on
    #[arg(short, long, default_value_t = 8080)]
    port: u16,
    /// Enable MITM mode; the CA certificate path to install is logged at startup
    #[arg(long)]
    mitm: bool,
}
// Subcommands. The `///` comments on variants and fields become clap help text.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Run the C2 tunnel server
    Server {
        /// Address the server listens on
        #[arg(long, default_value = "0.0.0.0")]
        host: String,
        /// Port the server listens on
        #[arg(short, long, default_value_t = 4443)]
        port: u16,
        /// Server name passed to the C2 server
        #[arg(long)]
        server_name: String,
        /// Directory containing the certificate authority
        #[arg(long)]
        ca_dir: PathBuf,
    },
    /// Connect to a C2 tunnel server
    Client {
        /// URL of the C2 server
        #[arg(long)]
        server_url: String,
        /// Client certificate file
        #[arg(long)]
        cert: PathBuf,
        /// Client private key file
        #[arg(long)]
        key: PathBuf,
        /// CA certificate file
        #[arg(long)]
        ca_cert: PathBuf,
    },
    /// Generate a client certificate and key from the CA in --ca-dir
    GenClientCert {
        /// Common Name for the certificate; also used to name the output files
        #[arg(long, default_value = "rustgate-client")]
        cn: String,
        /// Directory to write CN.pem and CN-key.pem into
        #[arg(long, default_value = ".")]
        out_dir: PathBuf,
        /// Directory containing the certificate authority
        #[arg(long)]
        ca_dir: PathBuf,
    },
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("rustgate=info".parse().unwrap()),
)
.init();
let cli = Cli::parse();
eprintln!("{BANNER}");
match cli.command {
None => run_proxy(cli.host, cli.port, cli.mitm).await,
Some(Commands::Server {
host,
port,
server_name,
ca_dir,
}) => run_server(host, port, server_name, ca_dir).await,
Some(Commands::Client {
server_url,
cert,
key,
ca_cert,
}) => run_client(server_url, cert, key, ca_cert).await,
Some(Commands::GenClientCert {
cn,
out_dir,
ca_dir,
}) => run_gen_client_cert(cn, out_dir, ca_dir).await,
}
}
/// Runs the local intercepting proxy on `host:port` and serves connections forever.
///
/// A `CertificateAuthority` is always loaded because it is part of `ProxyState`.
/// When `mitm` is set, the path of the CA certificate is logged so it can be
/// installed on clients. Each accepted connection is handled on its own task by
/// `rustgate::proxy::handle_connection`.
///
/// # Errors
/// Returns an error if the CA cannot be loaded, its certificate path cannot be
/// determined (MITM mode), or the listener cannot be bound. Errors from
/// individual `accept` calls are logged and do not stop the proxy.
async fn run_proxy(
    host: String,
    port: u16,
    mitm: bool,
) -> Result<(), Box<dyn std::error::Error>> {
    let ca = Arc::new(CertificateAuthority::new().await?);
    if mitm {
        let ca_path = CertificateAuthority::ca_cert_path()?;
        info!(
            "MITM mode enabled. Install CA cert: {}",
            ca_path.display()
        );
    }
    let state = Arc::new(ProxyState {
        ca,
        mitm,
        handler: Arc::new(LoggingHandler),
    });

    // Bind a (host, port) pair instead of `format!("{host}:{port}")`, which is
    // ambiguous for IPv6 literals such as `::1`.
    let listener = TcpListener::bind((host.as_str(), port)).await?;
    // The bound address reflects hostname resolution and an OS-chosen port (port 0).
    let local_addr = listener.local_addr()?;

    // Checking the resolved address also covers hostnames, not only IP literals.
    if !local_addr.ip().is_loopback() {
        warn!(
            "Binding to non-loopback address ({host}). \
             No authentication is configured — this proxy may be accessible from the network."
        );
    }
    info!("RustGate proxy listening on {local_addr}");

    loop {
        match listener.accept().await {
            Ok((stream, addr)) => {
                let state = Arc::clone(&state);
                tokio::spawn(async move {
                    rustgate::proxy::handle_connection(stream, addr, state).await;
                });
            }
            Err(e) => {
                // Accept failures (e.g. EMFILE, ECONNABORTED) are usually
                // transient. Back off briefly instead of shutting down the whole
                // proxy, and avoid a hot spin if the condition persists.
                warn!("Failed to accept connection: {e}");
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
            }
        }
    }
}
/// Starts the C2 server on `host:port`, using the certificate authority stored
/// in `ca_dir`.
///
/// # Errors
/// Propagates failures from loading the CA and from the server itself.
async fn run_server(
    host: String,
    port: u16,
    server_name: String,
    ca_dir: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
    let authority = CertificateAuthority::with_dir(ca_dir).await.map(Arc::new)?;
    let server = rustgate::c2::server::run(&host, port, &server_name, authority);
    server.await?;
    Ok(())
}
/// Connects to the C2 server at `server_url` using the given client certificate,
/// private key and CA certificate files.
///
/// # Errors
/// Returns an error if any path is not valid UTF-8 (the client API takes `&str`
/// paths), or if the client itself fails.
async fn run_client(
    server_url: String,
    cert: PathBuf,
    key: PathBuf,
    ca_cert: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
    rustgate::c2::client::run(
        &server_url,
        path_as_str(&cert, "--cert")?,
        path_as_str(&key, "--key")?,
        path_as_str(&ca_cert, "--ca-cert")?,
    )
    .await?;
    Ok(())
}

/// Borrows `path` as UTF-8. On failure, returns an error naming the CLI `flag`
/// instead of silently substituting an empty path.
fn path_as_str<'a>(path: &'a std::path::Path, flag: &str) -> Result<&'a str, String> {
    path.to_str()
        .ok_or_else(|| format!("{flag} path is not valid UTF-8: {}", path.display()))
}
/// Issues a client certificate from the CA in `ca_dir` and writes it to
/// `out_dir/<cn>.pem`, with the private key in `out_dir/<cn>-key.pem`.
///
/// `out_dir` is created if missing. On Unix, existing symlinks at either output
/// path are refused. The key is always written to a freshly created file with
/// mode `0o600`, so key material never sits in a file with looser permissions.
///
/// # Errors
/// Returns an error if `cn` is empty, contains a path separator or NUL, or
/// starts with `.`. Also returns an error if an output path is a symlink (Unix),
/// or on any CA or I/O failure.
async fn run_gen_client_cert(
    cn: String,
    out_dir: PathBuf,
    ca_dir: PathBuf,
) -> Result<(), Box<dyn std::error::Error>> {
    // `cn` becomes part of the output file names. Reject bad values before
    // loading the CA so they fail without any other work being done.
    validate_cn(&cn)?;
    let ca = CertificateAuthority::with_dir(ca_dir).await?;
    let (cert_pem, key_pem) = ca.generate_client_cert(&cn)?;
    tokio::fs::create_dir_all(&out_dir).await?;
    let cert_path = out_dir.join(format!("{cn}.pem"));
    let key_path = out_dir.join(format!("{cn}-key.pem"));
    #[cfg(unix)]
    {
        for path in [&cert_path, &key_path] {
            if let Ok(meta) = tokio::fs::symlink_metadata(path).await {
                if meta.file_type().is_symlink() {
                    return Err(format!("Refusing to overwrite symlink: {}", path.display()).into());
                }
            }
        }
    }
    tokio::fs::write(&cert_path, &cert_pem).await?;
    write_private_key(&key_path, key_pem.as_bytes()).await?;
    info!("Client certificate generated:");
    info!("  Cert: {}", cert_path.display());
    info!("  Key:  {}", key_path.display());
    info!("  CN:   {cn}");
    Ok(())
}

/// Checks that `cn` is safe to use as a file-name stem inside the output directory.
fn validate_cn(cn: &str) -> Result<(), Box<dyn std::error::Error>> {
    if cn.is_empty() {
        // An empty CN would produce `.pem` (a hidden file) and `-key.pem`.
        return Err("CN must not be empty".into());
    }
    if cn.contains(|c: char| matches!(c, '/' | '\\' | '\0')) || cn.starts_with('.') {
        return Err("CN must not contain path separators or start with '.'".into());
    }
    Ok(())
}

/// Writes `contents` to `path` in a newly created file with owner-only
/// (`0o600`) permissions.
///
/// An existing file is unlinked and recreated rather than overwritten in place.
/// The requested mode only applies when a file is created, so reusing the old
/// inode would briefly expose the key under the old permissions (and to anyone
/// already holding it open). `create_new` (O_EXCL) also refuses to open
/// through a symlink at `path`.
#[cfg(unix)]
async fn write_private_key(path: &std::path::Path, contents: &[u8]) -> std::io::Result<()> {
    use tokio::io::AsyncWriteExt;

    let mut options = tokio::fs::OpenOptions::new();
    options.write(true).create_new(true).mode(0o600);
    let mut file = match options.open(path).await {
        Ok(file) => file,
        Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
            tokio::fs::remove_file(path).await?;
            options.open(path).await?
        }
        Err(e) => return Err(e),
    };
    file.write_all(contents).await?;
    // tokio::fs::File completes writes in the background; flush so errors surface here.
    file.flush().await?;
    Ok(())
}

/// Writes `contents` to `path`; platform default permissions apply.
#[cfg(not(unix))]
async fn write_private_key(path: &std::path::Path, contents: &[u8]) -> std::io::Result<()> {
    tokio::fs::write(path, contents).await
}