#![recursion_limit="512"]
#![warn(rust_2018_idioms)]
use chrono::Local;
use std::sync::Arc;
use std::collections::HashMap;
use std::path::PathBuf;
use std::path::Path;
use log::*;
use argh::FromArgs;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
mod tcp;
mod udp;
mod dns;
use rproxy::{Proxy, Settings};
/// Location of the file holding the PID of the running config-mode instance
/// (`rproxy.pid` inside the OS temporary directory).
fn pid_file_path() -> PathBuf {
    let mut path = std::env::temp_dir();
    path.push("rproxy.pid");
    path
}
/// Location of the marker file used as a reload request on platforms without
/// Unix signals (`rproxy.reload` inside the OS temporary directory).
fn reload_signal_path() -> PathBuf {
    let mut path = std::env::temp_dir();
    path.push("rproxy.reload");
    path
}
// Command-line options, parsed with `argh::from_env()` in `main`.
//
// Plain `//` comments are used deliberately: argh uses `///` doc comments as
// `--help` text, so adding them here would change the CLI output.
#[allow(dead_code)]
#[derive(FromArgs)]
#[argh(description = "rproxy is a platform independent UDP TCP high performance async proxy")]
struct Options {
// Upstream address to forward to; only used when `config` is not given.
#[argh(option, short='r', default="\"\".to_string()")]
remote: String,
// Local address to listen on; only used when `config` is not given.
#[argh(option, short='b', default="\"\".to_string()")]
bind: String,
// Protocol for the single proxy started without `config`; `main` compares it
// against the exact strings "UDP" and "TCP".
#[argh(option, short='p', default="\"UDP\".to_string()")]
protocol: String,
// With the built-in logger, raise the max level from Info to Debug. Not
// consulted when a log4rs settings file is loaded.
#[argh(switch, short='d')]
debug: bool,
// Path to a log4rs configuration file; used only if the file exists,
// otherwise the built-in stdout logger is installed.
#[argh(option, short = 'l', default = "\"rproxy.logger.yaml\".to_string()")]
logger_settings: String,
// Configuration file listing the proxies to run. When set, `bind`, `remote`
// and `protocol` are ignored and the reloadable multi-proxy mode runs.
#[argh(option, short = 'c')]
config: Option<PathBuf>,
// Send a control signal to an already-running instance instead of starting
// one; `main` only accepts "reload".
#[argh(option, short = 's')]
signal: Option<String>,
// The four options below override the matching `Settings` field (from the
// config file, or `Settings::default()` without one) via `build_settings`.
#[argh(option)]
max_connections: Option<usize>,
#[argh(option)]
max_client_tunnels: Option<usize>,
// Logged by `main` with an `s` suffix, i.e. presumably seconds.
#[argh(option)]
keepalive_idle: Option<u64>,
// Logged by `main` with an `s` suffix, i.e. presumably seconds.
#[argh(option)]
keepalive_interval: Option<u64>,
}
/// Built-in fallback logger, installed by `main` when no log4rs settings file exists.
static MY_LOGGER: MyLogger = MyLogger;

/// Minimal logger that prints every record to stdout as
/// `[LEVEL][local timestamp] - message`.
///
/// It accepts every record itself; level filtering is left to the global
/// maximum level that `main` sets alongside installing it.
struct MyLogger;

impl log::Log for MyLogger {
    fn enabled(&self, _: &log::Metadata<'_>) -> bool {
        // No per-target filtering.
        true
    }

    fn log(&self, record: &log::Record<'_>) {
        if !self.enabled(record.metadata()) {
            return;
        }
        let timestamp = Local::now();
        println!("[{}][{}] - {}", record.level(), timestamp, record.args());
    }

    // This logger keeps no buffer of its own.
    fn flush(&self) {}
}
/// Merge command-line overrides into `base`.
///
/// Every override in `options` that is `Some` replaces the corresponding field
/// of `base`; fields without an override keep `base`'s value.
fn build_settings(options: &Options, base: Settings) -> Settings {
    let Settings {
        max_connections,
        max_client_tunnels,
        keepalive_idle,
        keepalive_interval,
    } = base;
    Settings {
        max_connections: options.max_connections.unwrap_or(max_connections),
        max_client_tunnels: options.max_client_tunnels.unwrap_or(max_client_tunnels),
        keepalive_idle: options.keepalive_idle.unwrap_or(keepalive_idle),
        keepalive_interval: options.keepalive_interval.unwrap_or(keepalive_interval),
    }
}
/// Record this process's PID at [`pid_file_path`] so `--signal reload` can
/// find it. Failure is logged as a warning and otherwise ignored.
fn write_pid_file() {
    let target = pid_file_path();
    let pid = std::process::id();
    match std::fs::write(&target, pid.to_string()) {
        Ok(()) => info!("PID {} written to {}", pid, target.display()),
        Err(err) => warn!("Failed to write PID file {}: {}", target.display(), err),
    }
}
/// Best-effort cleanup of the PID file and any leftover reload marker file;
/// removal errors (e.g. the file does not exist) are deliberately ignored.
fn remove_pid_file() {
    for leftover in [pid_file_path(), reload_signal_path()].iter() {
        let _ = std::fs::remove_file(leftover);
    }
}
/// Ask an already-running rproxy instance to reload its configuration.
///
/// The target PID is read from [`pid_file_path`]. On Unix the process is sent
/// `SIGHUP`; on other platforms a marker file is written to
/// [`reload_signal_path`], which the running instance polls for.
///
/// Terminates the current process with exit status 1 if the PID file is
/// missing or malformed, or if the signal / marker file cannot be delivered.
fn send_reload_signal() {
    let pid_path = pid_file_path();
    let content = match std::fs::read_to_string(&pid_path) {
        Ok(content) => content,
        Err(e) => {
            eprintln!("Failed to read PID file {}: {}. Is rproxy running?", pid_path.display(), e);
            std::process::exit(1);
        }
    };
    let pid = match content.trim().parse::<u32>() {
        Ok(pid) => pid,
        Err(e) => {
            eprintln!("Invalid PID in {}: {}", pid_path.display(), e);
            std::process::exit(1);
        }
    };

    #[cfg(unix)]
    {
        use nix::sys::signal::{kill, Signal};
        use nix::unistd::Pid;
        // kill(2) gives PID 0 and negative PIDs special meaning (signal the
        // process group / every process we are allowed to signal). A plain
        // `pid as i32` turns e.g. 4294967295 into -1, so only 1..=i32::MAX is
        // accepted as a real process id.
        if pid == 0 || pid > i32::MAX as u32 {
            eprintln!("Invalid PID in {}: {} is not a valid process id", pid_path.display(), pid);
            std::process::exit(1);
        }
        match kill(Pid::from_raw(pid as i32), Signal::SIGHUP) {
            Ok(()) => println!("Reload signal sent to rproxy (PID {})", pid),
            Err(e) => {
                eprintln!("Failed to send signal to PID {}: {}", pid, e);
                std::process::exit(1);
            }
        }
    }

    #[cfg(not(unix))]
    {
        let signal_path = reload_signal_path();
        match std::fs::write(&signal_path, "reload") {
            Ok(()) => println!("Reload signal written for rproxy (PID {})", pid),
            Err(e) => {
                eprintln!("Failed to write reload signal file: {}", e);
                std::process::exit(1);
            }
        }
    }
}
// A proxy started by `spawn_proxy`, kept in the `running` map keyed by its
// bind address.
struct ProxyTask {
// Join handle of the spawned proxy future. It is stored but not read by the
// code here, hence the `dead_code` allow.
#[allow(dead_code)]
handle: JoinHandle<Result<(), std::io::Error>>,
// Calling `cancel()` on this token asks the proxy task to stop; used when a
// proxy is removed/changed on reload and on Ctrl+C.
cancel: CancellationToken,
// The configuration entry the task was started from; `reload_config`
// compares it with the reloaded file to decide whether to restart.
proxy: Proxy,
}
fn spawn_proxy(proxy: &Proxy, settings: &Arc<Settings>) -> ProxyTask {
let bind = proxy.bind.clone();
let remote = proxy.remote.clone();
let protocol = proxy.protocol.clone();
let s = settings.clone();
let cancel = CancellationToken::new();
let c = cancel.clone();
let handle = if protocol == "UDP" {
tokio::spawn(async move { udp::udp_proxy(&bind, &remote, s, c).await })
} else {
tokio::spawn(async move { tcp::tcp_proxy(&bind, &remote, s, c).await })
};
ProxyTask {
handle,
cancel,
proxy: proxy.clone(),
}
}
fn reload_config(
config_path: &Path,
options: &Options,
running: &mut HashMap<String, ProxyTask>,
) {
let config = match rproxy::load_config(config_path) {
Ok(c) => c,
Err(e) => {
error!("Failed to reload configuration: {:?}", e);
return;
}
};
let settings = Arc::new(build_settings(options, config.settings));
let mut new_proxies: HashMap<String, &Proxy> = HashMap::new();
for proxy in &config.proxies {
new_proxies.insert(proxy.bind.clone(), proxy);
}
let old_keys: Vec<String> = running.keys().cloned().collect();
for key in &old_keys {
match new_proxies.get(key) {
None => {
info!("[reload] Stopping proxy on {}", key);
if let Some(task) = running.remove(key) {
task.cancel.cancel();
}
}
Some(new_proxy) => {
let old_task = running.get(key).unwrap();
if old_task.proxy.remote != new_proxy.remote || old_task.proxy.protocol != new_proxy.protocol {
info!("[reload] Restarting proxy on {} ({}->{})",
key, old_task.proxy.remote, new_proxy.remote);
if let Some(task) = running.remove(key) {
task.cancel.cancel();
}
} else {
debug!("[reload] Proxy on {} unchanged, keeping", key);
}
}
}
}
for (bind, proxy) in &new_proxies {
if !running.contains_key(bind) {
info!("[reload] Starting proxy on {} -> {} ({})", bind, proxy.remote, proxy.protocol);
let task = spawn_proxy(proxy, &settings);
running.insert(bind.clone(), task);
}
}
info!("[reload] Configuration reloaded: {} proxies active", running.len());
}
async fn wait_for_reload_signal() {
#[cfg(unix)]
{
let mut sighup = tokio::signal::unix::signal(
tokio::signal::unix::SignalKind::hangup()
).expect("Failed to register SIGHUP handler");
sighup.recv().await;
}
#[cfg(not(unix))]
{
let signal_path = reload_signal_path();
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
if signal_path.exists() {
let _ = std::fs::remove_file(&signal_path);
break;
}
}
}
}
#[tokio::main]
async fn main(){
let options: Options = argh::from_env();
if let Some(ref sig) = options.signal {
match sig.as_str() {
"reload" => {
send_reload_signal();
return;
}
_ => {
eprintln!("Unknown signal: {}. Supported: reload", sig);
std::process::exit(1);
}
}
}
if Path::new(&options.logger_settings).exists(){
log4rs::init_file(&options.logger_settings,
Default::default()).unwrap();
debug!("NICE");
} else {
log::set_logger(&MY_LOGGER).unwrap();
if options.debug {
log::set_max_level(log::LevelFilter::Debug);
} else {
log::set_max_level(log::LevelFilter::Info);
}
}
match options.config {
None => {
let settings = Arc::new(build_settings(&options, Settings::default()));
let cancel = CancellationToken::new();
if options.protocol == "UDP"{
udp::udp_proxy(&options.bind, &options.remote, settings, cancel).await.unwrap();
} else if options.protocol == "TCP" {
tcp::tcp_proxy(&options.bind, &options.remote, settings, cancel).await.unwrap();
}
},
Some(ref config_path) => {
if !config_path.as_path().exists() {
error!("Invalid configuration file path {}", config_path.as_path().display());
return;
}
write_pid_file();
let config = match rproxy::load_config(config_path.as_path()) {
Ok(c) => c,
Err(e) => {
error!("Failed to parse configuration: {:?}", e);
remove_pid_file();
return;
}
};
let settings = Arc::new(build_settings(&options, config.settings));
info!("Settings: max_connections={}, max_client_tunnels={}, keepalive_idle={}s, keepalive_interval={}s",
settings.max_connections, settings.max_client_tunnels,
settings.keepalive_idle, settings.keepalive_interval);
let mut running: HashMap<String, ProxyTask> = HashMap::new();
for proxy in &config.proxies {
let task = spawn_proxy(proxy, &settings);
running.insert(proxy.bind.clone(), task);
}
info!("Started {} proxies, listening for reload signal", running.len());
loop {
tokio::select! {
_ = wait_for_reload_signal() => {
info!("[reload] Reload signal received, reloading configuration from {}", config_path.display());
reload_config(config_path.as_path(), &options, &mut running);
},
_ = tokio::signal::ctrl_c() => {
info!("Ctrl+C received, shutting down all proxies...");
for (bind, task) in running.drain() {
info!("Stopping proxy on {}", bind);
task.cancel.cancel();
}
break;
}
}
}
remove_pid_file();
}
}
}