codive-relay 0.1.0

Relay server for secure tunneling
Documentation
//! Relay server binary
//!
//! This is the public relay server that enables remote access to local Codive
//! servers through secure tunnels.
//!
//! # Usage
//!
//! ```bash
//! # Start with defaults (localhost:3001)
//! codive-relay
//!
//! # Specify host and port
//! codive-relay --host 0.0.0.0 --port 8080
//!
//! # Specify base domain for tunnel URLs
//! codive-relay --domain relay.example.com
//! ```

use anyhow::Result;
use clap::Parser;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tracing::{info, Level};
use tracing_subscriber::EnvFilter;

use codive_relay::{create_relay_app, AuthRateLimitConfig, RelayConfig, RelayState};

// Command-line options for the relay binary.
//
// NOTE: with clap's derive API the `///` comments on the fields below are
// rendered as `--help` text, so they are user-facing output. Maintainer notes
// in this struct therefore use plain `//` comments.
#[derive(Parser, Debug)]
#[command(
    name = "codive-relay",
    about = "Relay server for secure Codive tunneling",
    version
)]
struct Cli {
    // ---- Network ----
    /// Host to bind to
    #[arg(long, default_value = "127.0.0.1")]
    host: String,

    /// Port to bind to
    #[arg(long, default_value_t = 3001)]
    port: u16,

    // The default below is a fixed string; it does not follow `--port`.
    /// Base domain for tunnel URLs (e.g., relay.example.com)
    #[arg(long, default_value = "localhost:3001")]
    domain: String,

    /// Use HTTPS for tunnel URLs
    #[arg(long)]
    https: bool,

    /// Request timeout in seconds
    #[arg(long, default_value_t = 30)]
    timeout: u64,

    /// Maximum tunnels per IP address
    #[arg(long, default_value_t = 10)]
    max_tunnels_per_ip: usize,

    // ---- Authentication ----
    /// Require authentication (API tokens) for tunnel connections
    #[arg(long)]
    require_auth: bool,

    /// Valid authentication tokens (can be specified multiple times)
    #[arg(long = "auth-token", action = clap::ArgAction::Append)]
    auth_tokens: Vec<String>,

    /// JWT secret for token-based authentication (enables JWT mode)
    #[arg(long)]
    jwt_secret: Option<String>,

    /// JWT token validity in seconds (default: 3600 = 1 hour)
    #[arg(long, default_value_t = 3600)]
    jwt_validity: u64,

    /// Maximum failed auth attempts before temporary ban (default: 5)
    #[arg(long, default_value_t = 5)]
    max_auth_failures: u32,

    /// Auth ban duration in seconds (default: 300 = 5 minutes)
    #[arg(long, default_value_t = 300)]
    auth_ban_duration: u64,

    // ---- Tunnel lifetime ----
    // For the two limits below, `main` turns 0 into `None` (no limit).
    /// Maximum tunnel age in seconds (TTL). Tunnels older than this are closed.
    /// Use 0 for no limit. Default: 0 (no limit)
    #[arg(long, default_value_t = 0)]
    max_tunnel_age: u64,

    /// Maximum idle time in seconds. Tunnels with no activity are closed.
    /// Use 0 for no limit. Default: 0 (no limit)
    #[arg(long, default_value_t = 0)]
    max_idle_time: u64,

    /// Disable custom tunnel IDs (force random IDs only).
    /// Recommended for public relays to prevent subdomain squatting.
    #[arg(long)]
    random_ids_only: bool,

    // ---- Logging ----
    // Counted flag: `-v`, `-vv`, `-vvv` yield 1, 2, 3.
    /// Verbosity level (0=warn, 1=info, 2=debug, 3+=trace)
    #[arg(short, long, action = clap::ArgAction::Count)]
    verbose: u8,
}

#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();

    // Initialize logging
    let log_level = match cli.verbose {
        0 => Level::WARN,
        1 => Level::INFO,
        2 => Level::DEBUG,
        _ => Level::TRACE,
    };

    tracing_subscriber::fmt()
        .with_env_filter(
            EnvFilter::from_default_env().add_directive(log_level.into()),
        )
        .init();

    // Build configuration
    let listen_addr: SocketAddr = format!("{}:{}", cli.host, cli.port).parse()?;

    let config = RelayConfig {
        base_domain: cli.domain.clone(),
        listen_addr,
        request_timeout: Duration::from_secs(cli.timeout),
        max_tunnels_per_ip: cli.max_tunnels_per_ip,
        use_https: cli.https,
        auth_tokens: cli.auth_tokens.into_iter().collect(),
        require_auth: cli.require_auth,
        jwt_secret: cli.jwt_secret.map(|s| s.into_bytes()),
        jwt_validity: Duration::from_secs(cli.jwt_validity),
        auth_rate_limit: AuthRateLimitConfig {
            max_failed_attempts: cli.max_auth_failures,
            ban_duration: Duration::from_secs(cli.auth_ban_duration),
            attempt_window: Duration::from_secs(60),
        },
        max_tunnel_age: if cli.max_tunnel_age > 0 {
            Some(Duration::from_secs(cli.max_tunnel_age))
        } else {
            None
        },
        max_idle_time: if cli.max_idle_time > 0 {
            Some(Duration::from_secs(cli.max_idle_time))
        } else {
            None
        },
        allow_custom_ids: !cli.random_ids_only,
    };

    // Capture values for banner before they're moved
    let https = config.use_https;
    let jwt_enabled = config.jwt_secret.is_some();
    let require_auth = config.require_auth;
    let domain = config.base_domain.clone();
    let max_auth_failures = config.auth_rate_limit.max_failed_attempts;
    let ban_duration = config.auth_rate_limit.ban_duration.as_secs();
    let max_tunnel_age = config.max_tunnel_age;
    let max_idle_time = config.max_idle_time;
    let allow_custom_ids = config.allow_custom_ids;
    let has_ttl = max_tunnel_age.is_some() || max_idle_time.is_some();

    // Create state
    let state = Arc::new(RelayState::new(config));

    // Create app
    let app = create_relay_app(state.clone());

    // Bind listener
    let listener = TcpListener::bind(listen_addr).await?;

    println!();
    println!("Agent Relay Server");
    println!("==================");
    println!();
    println!("Listening on: http://{}", listen_addr);
    println!("Base domain:  {}", domain);
    println!();
    println!("Security:");
    println!("  Auth required:     {}", if require_auth { "Yes" } else { "No" });
    println!("  JWT enabled:       {}", if jwt_enabled { "Yes" } else { "No" });
    println!("  Auth rate limit:   {} failures -> {} sec ban", max_auth_failures, ban_duration);
    println!("  Security headers:  HSTS, CSP, X-Frame-Options, X-XSS-Protection");
    println!();
    println!("Tunnel Limits:");
    println!("  Custom IDs:        {}", if allow_custom_ids { "Allowed" } else { "Random only" });
    println!("  Max age (TTL):     {}", max_tunnel_age.map(|d| format!("{} sec", d.as_secs())).unwrap_or_else(|| "No limit".to_string()));
    println!("  Max idle time:     {}", max_idle_time.map(|d| format!("{} sec", d.as_secs())).unwrap_or_else(|| "No limit".to_string()));
    println!();
    println!("Endpoints:");
    println!("  GET  /health         - Health check");
    println!("  GET  /agent          - WebSocket for agent connections");
    println!("  *    /*              - Proxy to tunnels (by subdomain)");
    println!();
    println!("Tunnel URL format: {}://<tunnel_id>.{}",
        if https { "https" } else { "http" },
        domain
    );
    println!();

    // Spawn cleanup task if TTL is configured
    if has_ttl {
        let cleanup_state = state.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(30));
            loop {
                interval.tick().await;
                let removed = cleanup_state.cleanup_expired_tunnels().await;
                if removed > 0 {
                    info!(removed, total = cleanup_state.tunnel_count(), "Cleaned up expired tunnels");
                }
            }
        });
        info!("Tunnel cleanup task started (runs every 30 seconds)");
    }

    // Run server
    info!("Starting relay server on {}", listen_addr);
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .with_graceful_shutdown(shutdown_signal())
    .await?;

    info!("Relay server stopped");
    Ok(())
}

/// Completes when the process is asked to shut down.
///
/// Races a Ctrl+C notification against a terminate signal and returns as soon
/// as either arrives, logging which one fired. On non-Unix targets the
/// terminate branch is a future that never completes, so only Ctrl+C can end
/// the wait there.
///
/// # Panics
///
/// Panics if either signal handler cannot be installed.
async fn shutdown_signal() {
    // Terminate-signal branch (Unix only).
    #[cfg(unix)]
    let sigterm = async {
        let mut stream =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("Failed to install signal handler");
        stream.recv().await;
    };

    // Elsewhere there is no terminate signal to wait for: never resolve.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    // Ctrl+C branch, available on every platform.
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    tokio::select! {
        _ = interrupt => info!("Received Ctrl+C"),
        _ = sigterm => info!("Received SIGTERM"),
    }
}