mod registry;
use std::process::{Command, Stdio};
use std::time::Duration;
pub use registry::{DaemonInfo, DaemonRole};
/// Environment marker set (to `"1"`) by `spawn_daemon` on the process it launches.
/// `run_restart_loop` removes it again from the environment of the wsproxy
/// process it starts.
const DAEMON_ENV_VAR: &str = "__WSPROXY_DAEMON_CHILD";
/// Set to `"1"` by `run_restart_loop` on the wsproxy process it starts; that
/// process is also given a piped stdin whose write end the loop holds open.
const MONITOR_STDIN_VAR: &str = "__WSPROXY_MONITOR_STDIN";
/// Registry ID assigned by `spawn_daemon`, passed to the launched process so its
/// Ctrl-C handler can unregister itself. Not forwarded by `run_restart_loop`.
const DAEMON_ID_VAR: &str = "__WSPROXY_DAEMON_ID";
/// Returns `true` when `DAEMON_ENV_VAR` is present in this process's environment.
///
/// `spawn_daemon` sets that variable on the process it launches. Only presence
/// matters, so `var_os` is used: unlike `var`, it does not allocate and does not
/// treat a non-UTF-8 value as "absent".
pub fn is_daemon_child() -> bool {
    std::env::var_os(DAEMON_ENV_VAR).is_some()
}
/// Returns `true` when `MONITOR_STDIN_VAR` is present in this process's environment.
///
/// `run_restart_loop` sets that variable on the wsproxy process it starts. Only
/// presence matters, so `var_os` is used: it neither allocates nor treats a
/// non-UTF-8 value as "absent".
pub fn should_monitor_stdin() -> bool {
    std::env::var_os(MONITOR_STDIN_VAR).is_some()
}
/// Reads the registry ID handed to this process through `DAEMON_ID_VAR`.
///
/// Yields `None` when the variable is missing, is not valid Unicode, or does
/// not parse as a `u32`.
fn get_daemon_id() -> Option<u32> {
    match std::env::var(DAEMON_ID_VAR) {
        Ok(raw) => raw.parse::<u32>().ok(),
        Err(_) => None,
    }
}
/// Supervises the real wsproxy process: launches it and relaunches it every
/// time it exits. Never returns.
///
/// The child command line is this process's own argv with the first `"daemon"`
/// argument removed; `argv[0]` is used as the program to run. The child runs
/// without `DAEMON_ENV_VAR` / `DAEMON_ID_VAR`, with `MONITOR_STDIN_VAR=1`, and
/// with a stdin pipe whose write end is held here (never written to) while the
/// child runs.
///
/// Restarts are delayed with exponential backoff from `MIN_BACKOFF_MS` up to
/// `MAX_BACKOFF_MS`. The backoff resets after a successful exit, and also after
/// any run lasting at least `HEALTHY_RUN`. Without that second reset, a child
/// that crashes once after days of uptime would inherit a delay built up by
/// unrelated failures long ago.
///
/// On Ctrl-C the registry entry named by `DAEMON_ID_VAR` (if any) is removed
/// and the process exits with status 0. Failure to install that handler is
/// ignored.
pub fn run_restart_loop() -> ! {
    const MIN_BACKOFF_MS: u64 = 1;
    const MAX_BACKOFF_MS: u64 = 5 * 60 * 1000;
    // A child that stays up at least this long is treated as healthy, so the
    // next restart starts again from the minimum backoff.
    const HEALTHY_RUN: Duration = Duration::from_secs(60);

    let daemon_id = get_daemon_id();
    let child_args = restart_child_args(std::env::args().collect());
    let (program, program_args) = match child_args.split_first() {
        Some(split) => split,
        None => {
            // No argv[0] means there is nothing to relaunch; bail out cleanly
            // instead of panicking on an index.
            eprintln!("Cannot start wsproxy: no program path in arguments");
            if let Some(id) = daemon_id {
                registry::unregister(id).ok();
            }
            std::process::exit(1);
        }
    };
    let role = restart_role_label(program_args);

    ctrlc::set_handler(move || {
        if let Some(id) = daemon_id {
            registry::unregister(id).ok();
        }
        std::process::exit(0);
    })
    .ok();

    let mut backoff_ms = MIN_BACKOFF_MS;
    loop {
        eprintln!("Starting wsproxy {}...", role);
        let spawned = Command::new(program)
            .args(program_args)
            .env_remove(DAEMON_ENV_VAR)
            .env_remove(DAEMON_ID_VAR)
            .env(MONITOR_STDIN_VAR, "1")
            .stdin(Stdio::piped())
            .spawn();
        match spawned {
            Ok(mut child) => {
                let started = std::time::Instant::now();
                // Keep the write end of the child's stdin pipe open while it
                // runs; dropping it delivers EOF to the child.
                let _stdin_handle = child.stdin.take();
                let status = child.wait();
                let ran_long_enough = started.elapsed() >= HEALTHY_RUN;
                match status {
                    Ok(status) if status.success() => {
                        backoff_ms = MIN_BACKOFF_MS;
                        eprintln!("wsproxy {} exited successfully", role);
                    }
                    Ok(status) => {
                        eprintln!("wsproxy {} exited with status: {}", role, status);
                    }
                    Err(e) => {
                        eprintln!("Failed to wait for wsproxy {}: {}", role, e);
                    }
                }
                if ran_long_enough {
                    backoff_ms = MIN_BACKOFF_MS;
                }
            }
            Err(e) => {
                eprintln!("Failed to start wsproxy {}: {}", role, e);
            }
        }
        eprintln!("Restarting in {} ms...", backoff_ms);
        std::thread::sleep(Duration::from_millis(backoff_ms));
        backoff_ms = (backoff_ms * 2).min(MAX_BACKOFF_MS);
    }
}

/// Returns `args` with its first `"daemon"` argument (if any) removed; later
/// occurrences are kept as-is.
fn restart_child_args(mut args: Vec<String>) -> Vec<String> {
    if let Some(pos) = args.iter().position(|arg| arg == "daemon") {
        args.remove(pos);
    }
    args
}

/// Label used in restart-loop log messages: `"client"` when the first argument
/// after the program path is `client`, otherwise `"server"`.
fn restart_role_label(program_args: &[String]) -> &'static str {
    if program_args.first().map(String::as_str) == Some("client") {
        "client"
    } else {
        "server"
    }
}
/// Starts a detached `server` daemon configured by the given options.
///
/// When `config` is `Some`, only `--config <path>` is forwarded and every other
/// option is ignored. Otherwise each provided option becomes its command-line
/// flag (`--listen`, one `--route` per entry, `--default-target`, `--tls-cert`,
/// `--tls-key`, `--tls-self-signed`).
///
/// # Errors
///
/// Propagates any error from `spawn_daemon`.
pub fn spawn_server(
    config: Option<String>,
    listen: Option<String>,
    route: Vec<String>,
    default_target: Option<String>,
    tls_cert: Option<String>,
    tls_key: Option<String>,
    tls_self_signed: bool,
) -> wsproxy::Result<()> {
    // Appends `flag value`, moving the owned value instead of cloning it.
    fn push_flag(args: &mut Vec<String>, flag: &str, value: String) {
        args.push(flag.to_string());
        args.push(value);
    }

    let mut args = vec!["server".to_string()];
    match config {
        // A config file takes precedence: all individual options are dropped.
        Some(config_path) => push_flag(&mut args, "--config", config_path),
        None => {
            if let Some(listen_addr) = listen {
                push_flag(&mut args, "--listen", listen_addr);
            }
            for r in route {
                push_flag(&mut args, "--route", r);
            }
            if let Some(target) = default_target {
                push_flag(&mut args, "--default-target", target);
            }
            if let Some(cert) = tls_cert {
                push_flag(&mut args, "--tls-cert", cert);
            }
            if let Some(key) = tls_key {
                push_flag(&mut args, "--tls-key", key);
            }
            if tls_self_signed {
                args.push("--tls-self-signed".to_string());
            }
        }
    }
    spawn_daemon(DaemonRole::Server, args)
}
/// Starts a detached `client` daemon that listens on `listen` and connects to
/// `server`.
///
/// `--insecure` is forwarded when `insecure` is set, and `--tls-ca-cert <path>`
/// when a CA certificate path is given.
///
/// # Errors
///
/// Propagates any error from `spawn_daemon`.
pub fn spawn_client(
    listen: String,
    server: String,
    insecure: bool,
    tls_ca_cert: Option<String>,
) -> wsproxy::Result<()> {
    let mut argv: Vec<String> = Vec::new();
    argv.push(String::from("client"));
    argv.push(String::from("--listen"));
    argv.push(listen);
    argv.push(String::from("--server"));
    argv.push(server);
    if insecure {
        argv.push(String::from("--insecure"));
    }
    if let Some(ca_path) = tls_ca_cert {
        argv.push(String::from("--tls-ca-cert"));
        argv.push(ca_path);
    }
    spawn_daemon(DaemonRole::Client, argv)
}
fn spawn_daemon(role: DaemonRole, args: Vec<String>) -> wsproxy::Result<()> {
let exe = std::env::current_exe()
.map_err(|e| wsproxy::Error::config(format!("Failed to get current executable: {}", e)))?;
let id = {
let _lock = registry::FileLock::acquire()
.map_err(|e| wsproxy::Error::config(format!("Failed to acquire lock: {}", e)))?;
let daemons = registry::read();
daemons.iter().map(|d| d.id).max().unwrap_or(0) + 1
};
let mut cmd = Command::new(&exe);
cmd.arg("daemon");
cmd.args(&args);
cmd.env(DAEMON_ENV_VAR, "1");
cmd.env(DAEMON_ID_VAR, id.to_string());
cmd.stdin(Stdio::null());
cmd.stdout(Stdio::null());
cmd.stderr(Stdio::inherit());
let child = cmd
.spawn()
.map_err(|e| wsproxy::Error::config(format!("Failed to spawn daemon process: {}", e)))?;
let pid = child.id();
{
let _lock = registry::FileLock::acquire()
.map_err(|e| wsproxy::Error::config(format!("Failed to acquire lock: {}", e)))?;
let mut daemons = registry::read();
let started_at = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
daemons.push(registry::DaemonInfo {
id,
pid,
role,
args: args.clone(),
started_at,
});
registry::write(&daemons)
.map_err(|e| wsproxy::Error::config(format!("Failed to write registry: {}", e)))?;
}
eprintln!("Daemon started with ID {} (PID {})", id, pid);
Ok(())
}
pub fn list() -> wsproxy::Result<Vec<DaemonInfo>> {
let _lock = registry::FileLock::acquire()
.map_err(|e| wsproxy::Error::config(format!("Failed to acquire lock: {}", e)))?;
let daemons = registry::read();
let (alive, _dead): (Vec<_>, Vec<_>) = daemons
.into_iter()
.partition(|d| registry::is_process_alive(d.pid));
registry::write(&alive)
.map_err(|e| wsproxy::Error::config(format!("Failed to write registry: {}", e)))?;
Ok(alive)
}
/// Terminates the daemon registered under `id` and removes its registry entry.
///
/// Returns `Ok(false)` when no entry has that ID, or when
/// `registry::kill_process` reports failure; in the latter case the entry is
/// left in the registry. Returns `Ok(true)` once the process was killed and the
/// entry removed.
///
/// # Errors
///
/// Returns a config error if the lock cannot be acquired or the updated
/// registry cannot be written.
pub fn kill(id: u32) -> wsproxy::Result<bool> {
    let _guard = registry::FileLock::acquire()
        .map_err(|e| wsproxy::Error::config(format!("Failed to acquire lock: {}", e)))?;
    let mut entries = registry::read();
    let index = match entries.iter().position(|entry| entry.id == id) {
        Some(index) => index,
        None => return Ok(false),
    };
    if !registry::kill_process(entries[index].pid) {
        return Ok(false);
    }
    entries.remove(index);
    registry::write(&entries)
        .map_err(|e| wsproxy::Error::config(format!("Failed to write registry: {}", e)))?;
    Ok(true)
}
/// Completes once stdin reaches end-of-file or a non-retryable read error
/// occurs.
///
/// Any bytes that arrive are read and discarded; only EOF (`Ok(0)`) or an error
/// ends the wait. `run_restart_loop` gives its child a stdin pipe whose write
/// end it holds but never writes to, so EOF there means the supervisor has gone
/// away. Presumably callers pair this with `should_monitor_stdin()` — confirm at
/// the call site.
pub async fn wait_for_stdin_close() {
    use tokio::io::AsyncReadExt;
    let mut stdin = tokio::io::stdin();
    let mut buf = [0u8; 64];
    loop {
        match stdin.read(&mut buf).await {
            // EOF: the writer side of stdin has been closed.
            Ok(0) => break,
            // Data is not a close signal; discard it and keep waiting.
            Ok(_) => continue,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            // Any other error: treat stdin as unusable, i.e. closed.
            Err(_) => break,
        }
    }
}