#![forbid(unsafe_code)]
use clap::Parser;
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tracing::{debug, error, info, warn};
use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberInitExt};
use wisegate::args::Args;
use wisegate::config::EnvVarConfig;
use wisegate::connection::{ConnectionLimiter, ConnectionTracker};
use wisegate::server::StartupConfig;
use wisegate::{ConnectionProvider, ProxyProvider, RateLimiter};
use wisegate::{request_handler, server};
/// Maximum time, in seconds, that shutdown waits for tracked connections to
/// finish (converted to a `Duration` and passed to
/// `ConnectionTracker::wait_for_shutdown`) before logging a forced shutdown.
const SHUTDOWN_TIMEOUT_SECS: u64 = 30;
/// Installs the process-wide `tracing` subscriber.
///
/// A filter taken from the environment (`EnvFilter::try_from_default_env`,
/// i.e. `RUST_LOG`) always wins. Without one, the level comes from the flags:
/// `quiet` selects `error` (and takes precedence over `verbose`), `verbose`
/// selects `debug`, and otherwise `info` is used.
///
/// With `json_logs` the output is structured JSON. Otherwise it is the
/// human-readable formatter with event targets hidden.
///
/// # Panics
///
/// `SubscriberInitExt::init` panics if a global subscriber is already set,
/// so this must be called at most once per process.
fn init_tracing(verbose: bool, quiet: bool, json_logs: bool) {
    let fallback_level = match (quiet, verbose) {
        (true, _) => "error",
        (false, true) => "debug",
        (false, false) => "info",
    };
    let filter =
        EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(fallback_level));

    // Both output formats share the same filtered registry; only the fmt layer differs.
    let base = tracing_subscriber::registry().with(filter);
    if json_logs {
        base.with(fmt::layer().json()).init();
    } else {
        base.with(fmt::layer().with_target(false)).init();
    }
}
#[tokio::main]
async fn main() {
let args = Args::parse();
let bind_ip = match args.validate() {
Ok(ip) => ip,
Err(err) => {
eprintln!("Configuration error: {err}");
std::process::exit(1);
}
};
init_tracing(args.verbose, args.quiet, args.json_logs);
let rate_limiter = RateLimiter::new();
let env_config = Arc::new(EnvVarConfig::new());
let startup_config = StartupConfig {
listen_port: args.listen,
forward_port: args.forward,
bind_address: args.bind.clone(),
verbose: args.verbose,
quiet: args.quiet,
};
server::print_startup_info(&startup_config, &*env_config);
let http_client = reqwest::Client::builder()
.timeout(env_config.proxy_config().timeout)
.pool_max_idle_per_host(32)
.build()
.unwrap_or_else(|_| reqwest::Client::new());
let bind_addr = SocketAddr::from((bind_ip, args.listen));
let listener = match TcpListener::bind(bind_addr).await {
Ok(listener) => listener,
Err(err) => {
error!(port = args.listen, error = %err, "Failed to bind to port");
std::process::exit(1);
}
};
let max_connections = env_config.max_connections();
let connection_limiter = ConnectionLimiter::new(max_connections);
if connection_limiter.is_enabled() {
info!(
max_connections = max_connections,
"Connection limit configured"
);
} else {
warn!("No connection limit configured (MAX_CONNECTIONS=0)");
}
info!(port = args.listen, bind = %args.bind, "WiseGate is running");
let forward_host: Arc<str> = Arc::from(args.bind.as_str());
let connection_tracker = ConnectionTracker::new();
loop {
tokio::select! {
accept_result = listener.accept() => {
let (stream, addr) = match accept_result {
Ok(conn) => conn,
Err(err) => {
warn!(error = %err, "Failed to accept connection");
continue;
}
};
let permit = if connection_limiter.is_enabled() {
match connection_limiter.try_acquire() {
Some(permit) => Some(permit),
None => {
warn!(client = %addr, max = max_connections, "Connection rejected: server at capacity");
drop(stream);
continue;
}
}
} else {
None
};
debug!(client = %addr, "New connection");
let io = TokioIo::new(stream);
let limiter = rate_limiter.clone();
let forward_host = forward_host.clone();
let forward_port = args.forward;
let tracker = connection_tracker.clone();
let config = env_config.clone();
let client = http_client.clone();
tokio::task::spawn(async move {
let _permit = permit;
let _conn_guard = tracker.track();
let service = service_fn(move |req| {
request_handler::handle_request(req, forward_host.clone(), forward_port, limiter.clone(), config.clone(), client.clone())
});
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
warn!(client = %addr, error = %err, "Connection error");
}
});
}
_ = shutdown_signal() => {
info!("Shutdown signal received, stopping gracefully...");
break;
}
}
}
let active = connection_tracker.count();
if active > 0 {
info!(
active_connections = active,
"Waiting for connections to finish..."
);
let timeout = Duration::from_secs(SHUTDOWN_TIMEOUT_SECS);
if !connection_tracker.wait_for_shutdown(timeout).await {
let remaining = connection_tracker.count();
warn!(
remaining_connections = remaining,
"Timeout reached, forcing shutdown"
);
}
}
info!("WiseGate stopped cleanly");
}
/// Resolves once the process is asked to stop.
///
/// On every platform this means Ctrl+C (`tokio::signal::ctrl_c`). On Unix,
/// SIGTERM is accepted as well. Elsewhere the SIGTERM arm never completes.
///
/// # Panics
///
/// Panics if the Ctrl+C or SIGTERM handler cannot be installed. Both are
/// registered on the first poll, not when this function is called.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        use tokio::signal::unix::{signal, SignalKind};

        let mut stream =
            signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
        stream.recv().await;
    };

    // No SIGTERM equivalent here: this arm simply never fires.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        () = interrupt => {}
        () = sigterm => {}
    }
}