use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, Result};
use clap::Parser;
use tokio::net::TcpListener;
use tracing::{error, info, Level};
use tracing_subscriber::EnvFilter;

use codive_relay::{create_relay_app, AuthRateLimitConfig, RelayConfig, RelayState};
// Command-line options for the `codive-relay` binary.
// The `///` doc comment on each field is rendered by clap as that flag's `--help`
// text, so keep those comments short and user-facing.
#[derive(Parser, Debug)]
#[command(name = "codive-relay")]
#[command(about = "Relay server for secure Codive tunneling")]
#[command(version)]
struct Cli {
    /// IP address to listen on (hostnames are not resolved)
    #[arg(long, default_value = "127.0.0.1")]
    host: String,
    /// TCP port to listen on
    #[arg(long, default_value = "3001")]
    port: u16,
    /// Base domain for tunnel URLs, which take the form <tunnel_id>.<domain>
    #[arg(long, default_value = "localhost:3001")]
    domain: String,
    /// Use https:// in public tunnel URLs (the listener itself serves plain HTTP)
    #[arg(long)]
    https: bool,
    /// Request timeout, in seconds
    #[arg(long, default_value = "30")]
    timeout: u64,
    /// Maximum number of tunnels per client IP
    #[arg(long, default_value = "10")]
    max_tunnels_per_ip: usize,
    /// Require clients to authenticate
    #[arg(long)]
    require_auth: bool,
    // NOTE(review): secrets passed as CLI arguments are visible to other local users
    // via process listings; consider also accepting them from the environment.
    /// Accepted auth token; repeat the flag to accept several tokens
    #[arg(long = "auth-token", action = clap::ArgAction::Append)]
    auth_tokens: Vec<String>,
    /// Secret for JWT authentication; JWT support is enabled only when this is set
    #[arg(long)]
    jwt_secret: Option<String>,
    /// Validity period of JWTs, in seconds
    #[arg(long, default_value = "3600")]
    jwt_validity: u64,
    /// Failed authentication attempts (within a 60-second window) that trigger a ban
    #[arg(long, default_value = "5")]
    max_auth_failures: u32,
    /// Duration of an authentication ban, in seconds
    #[arg(long, default_value = "300")]
    auth_ban_duration: u64,
    /// Maximum tunnel age, in seconds; 0 means no limit
    #[arg(long, default_value = "0")]
    max_tunnel_age: u64,
    /// Maximum tunnel idle time, in seconds; 0 means no limit
    #[arg(long, default_value = "0")]
    max_idle_time: u64,
    /// Disallow custom tunnel IDs; only random IDs are used
    #[arg(long)]
    random_ids_only: bool,
    /// Increase log verbosity (-v info, -vv debug, -vvv trace; default warn)
    #[arg(short, long, action = clap::ArgAction::Count)]
    verbose: u8,
}
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
let log_level = match cli.verbose {
0 => Level::WARN,
1 => Level::INFO,
2 => Level::DEBUG,
_ => Level::TRACE,
};
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::from_default_env().add_directive(log_level.into()),
)
.init();
let listen_addr: SocketAddr = format!("{}:{}", cli.host, cli.port).parse()?;
let config = RelayConfig {
base_domain: cli.domain.clone(),
listen_addr,
request_timeout: Duration::from_secs(cli.timeout),
max_tunnels_per_ip: cli.max_tunnels_per_ip,
use_https: cli.https,
auth_tokens: cli.auth_tokens.into_iter().collect(),
require_auth: cli.require_auth,
jwt_secret: cli.jwt_secret.map(|s| s.into_bytes()),
jwt_validity: Duration::from_secs(cli.jwt_validity),
auth_rate_limit: AuthRateLimitConfig {
max_failed_attempts: cli.max_auth_failures,
ban_duration: Duration::from_secs(cli.auth_ban_duration),
attempt_window: Duration::from_secs(60),
},
max_tunnel_age: if cli.max_tunnel_age > 0 {
Some(Duration::from_secs(cli.max_tunnel_age))
} else {
None
},
max_idle_time: if cli.max_idle_time > 0 {
Some(Duration::from_secs(cli.max_idle_time))
} else {
None
},
allow_custom_ids: !cli.random_ids_only,
};
let https = config.use_https;
let jwt_enabled = config.jwt_secret.is_some();
let require_auth = config.require_auth;
let domain = config.base_domain.clone();
let max_auth_failures = config.auth_rate_limit.max_failed_attempts;
let ban_duration = config.auth_rate_limit.ban_duration.as_secs();
let max_tunnel_age = config.max_tunnel_age;
let max_idle_time = config.max_idle_time;
let allow_custom_ids = config.allow_custom_ids;
let has_ttl = max_tunnel_age.is_some() || max_idle_time.is_some();
let state = Arc::new(RelayState::new(config));
let app = create_relay_app(state.clone());
let listener = TcpListener::bind(listen_addr).await?;
println!();
println!("Agent Relay Server");
println!("==================");
println!();
println!("Listening on: http://{}", listen_addr);
println!("Base domain: {}", domain);
println!();
println!("Security:");
println!(" Auth required: {}", if require_auth { "Yes" } else { "No" });
println!(" JWT enabled: {}", if jwt_enabled { "Yes" } else { "No" });
println!(" Auth rate limit: {} failures -> {} sec ban", max_auth_failures, ban_duration);
println!(" Security headers: HSTS, CSP, X-Frame-Options, X-XSS-Protection");
println!();
println!("Tunnel Limits:");
println!(" Custom IDs: {}", if allow_custom_ids { "Allowed" } else { "Random only" });
println!(" Max age (TTL): {}", max_tunnel_age.map(|d| format!("{} sec", d.as_secs())).unwrap_or_else(|| "No limit".to_string()));
println!(" Max idle time: {}", max_idle_time.map(|d| format!("{} sec", d.as_secs())).unwrap_or_else(|| "No limit".to_string()));
println!();
println!("Endpoints:");
println!(" GET /health - Health check");
println!(" GET /agent - WebSocket for agent connections");
println!(" * /* - Proxy to tunnels (by subdomain)");
println!();
println!("Tunnel URL format: {}://<tunnel_id>.{}",
if https { "https" } else { "http" },
domain
);
println!();
if has_ttl {
let cleanup_state = state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
let removed = cleanup_state.cleanup_expired_tunnels().await;
if removed > 0 {
info!(removed, total = cleanup_state.tunnel_count(), "Cleaned up expired tunnels");
}
}
});
info!("Tunnel cleanup task started (runs every 30 seconds)");
}
info!("Starting relay server on {}", listen_addr);
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await?;
info!("Relay server stopped");
Ok(())
}
async fn shutdown_signal() {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
.expect("Failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => info!("Received Ctrl+C"),
_ = terminate => info!("Received SIGTERM"),
}
}