ares-api 0.4.0

HTTP server for the Ares AI scraper
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use axum::http::HeaderValue;
use tokio::net::TcpListener;
use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::trace::TraceLayer;
use tracing_subscriber::EnvFilter;

use ares_api::routes;
use ares_api::state::AppState;
use ares_core::proxy::{ProxyConfig, ProxyEntry, RotationStrategy, TlsBackend};
use ares_db::{Database, DatabaseConfig};

/// Server entry point.
///
/// Reads configuration from the environment (after loading an optional `.env`
/// file), connects to the database and runs migrations, builds the shared
/// [`AppState`], wraps the router in rate-limiting, body-size, tracing and CORS
/// layers, and serves on `0.0.0.0:<ARES_SERVER_PORT>` until a shutdown signal
/// is received.
///
/// Environment variables read here:
/// - `ARES_ADMIN_TOKEN` (optional)
/// - `ARES_SERVER_PORT` (default `3000`)
/// - `ARES_SCHEMAS_DIR` (default `schemas`)
/// - `ARES_RANDOM_UA`, `ARES_BROWSER`, `ARES_STEALTH` (`true` or `1` enables)
/// - `ARES_TLS_BACKEND` (default `rustls`)
/// - `ARES_RATE_LIMIT_BURST` (default `30`), `ARES_RATE_LIMIT_RPS` (default `1`)
/// - `ARES_BODY_SIZE_LIMIT` in bytes (default 2 MB)
/// - `ARES_CORS_ORIGIN` (`*` or a comma-separated list of origins)
///
/// # Errors
///
/// Returns an error if the log directive, database configuration/connection,
/// migrations, proxy configuration, TLS backend, or rate-limit configuration is
/// invalid, or if binding/serving the listener fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // A missing `.env` file is not an error; real environment variables still apply.
    let _ = dotenvy::dotenv();

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive("ares=info".parse()?))
        .with_target(false)
        .init();

    let admin_token = std::env::var("ARES_ADMIN_TOKEN").ok();
    let port = std::env::var("ARES_SERVER_PORT").unwrap_or_else(|_| "3000".to_string());
    let addr = format!("0.0.0.0:{port}");
    let schemas_dir =
        PathBuf::from(std::env::var("ARES_SCHEMAS_DIR").unwrap_or_else(|_| "schemas".to_string()));

    let db = Database::connect(&DatabaseConfig::from_env()?).await?;
    db.migrate().await?;

    if admin_token.is_some() {
        tracing::info!("Admin authentication: enabled");
    } else {
        tracing::info!("Admin authentication: disabled (set ARES_ADMIN_TOKEN to enable)");
    }
    tracing::info!("Schemas directory: {}", schemas_dir.display());

    // -- Proxy / UA rotation (server-level) --
    // A flag is on only when its variable is exactly "true" or "1".
    let env_flag = |name: &str| {
        std::env::var(name)
            .map(|v| v == "true" || v == "1")
            .unwrap_or(false)
    };

    let proxy_config = build_proxy_config()?;
    let random_ua = env_flag("ARES_RANDOM_UA");
    let browser = env_flag("ARES_BROWSER");
    let stealth = env_flag("ARES_STEALTH");
    let tls_backend: TlsBackend = std::env::var("ARES_TLS_BACKEND")
        .unwrap_or_else(|_| "rustls".to_string())
        .parse()
        .map_err(|e: String| anyhow::anyhow!("{e}"))?;

    if proxy_config.is_some() {
        tracing::info!("Proxy rotation: enabled");
    }
    if random_ua {
        tracing::info!("User-Agent rotation: enabled");
    }
    if browser {
        tracing::info!("Browser mode: enabled");
    }
    if stealth {
        tracing::info!("Browser stealth: enabled");
    }
    if !matches!(tls_backend, TlsBackend::Rustls) {
        tracing::info!("TLS backend: {tls_backend}");
    }

    let state = Arc::new(AppState {
        db,
        admin_token,
        schemas_dir,
        proxy_config,
        random_ua,
        browser,
        stealth,
        tls_backend,
    });

    // -- Rate limiting (per-IP) --
    let burst_size = env_parse("ARES_RATE_LIMIT_BURST", 30);
    let per_second = env_parse("ARES_RATE_LIMIT_RPS", 1);
    let body_limit = env_parse("ARES_BODY_SIZE_LIMIT", 2 * 1024 * 1024); // 2 MB

    // The values come from the environment, so a rejected configuration is an
    // operator error to report, not an invariant violation to panic on.
    let governor_conf = GovernorConfigBuilder::default()
        .per_second(per_second)
        .burst_size(burst_size)
        .finish()
        .ok_or_else(|| {
            anyhow::anyhow!(
                "Invalid rate limit configuration \
                 (ARES_RATE_LIMIT_RPS={per_second}, ARES_RATE_LIMIT_BURST={burst_size})"
            )
        })?;

    tracing::info!(burst_size, per_second, body_limit, "Rate limiting: enabled");

    // Background task to clean up stale rate-limit entries
    let governor_limiter = governor_conf.limiter().clone();
    tokio::spawn(async move {
        loop {
            tokio::time::sleep(Duration::from_secs(60)).await;
            tracing::debug!(
                "Rate limiter storage size: {} (cleaning up)",
                governor_limiter.len()
            );
            governor_limiter.retain_recent();
        }
    });

    // -- CORS --
    let cors = match std::env::var("ARES_CORS_ORIGIN") {
        Ok(raw) => {
            let entries: Vec<&str> = raw
                .split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .collect();

            // `AllowOrigin::list` must never be handed a wildcard, so any `*`
            // entry (even padded or mixed with other origins) selects the
            // permissive layer instead.
            if entries.contains(&"*") {
                CorsLayer::permissive()
            } else {
                let origins: Vec<HeaderValue> = entries
                    .into_iter()
                    .filter_map(|entry| match entry.parse::<HeaderValue>() {
                        Ok(value) => Some(value),
                        Err(err) => {
                            // Still skipped as before, but no longer silently.
                            tracing::warn!(
                                "Ignoring invalid ARES_CORS_ORIGIN entry '{entry}': {err}"
                            );
                            None
                        }
                    })
                    .collect();
                CorsLayer::new().allow_origin(AllowOrigin::list(origins))
            }
        }
        Err(_) => CorsLayer::new(),
    };

    let app = routes::router(state)
        .layer(GovernorLayer::new(governor_conf))
        .layer(RequestBodyLimitLayer::new(body_limit))
        .layer(TraceLayer::new_for_http())
        .layer(cors);

    tracing::info!("Starting server on {addr}");
    let listener = TcpListener::bind(&addr).await?;
    // Connect info is attached so the peer `SocketAddr` is available per request.
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .with_graceful_shutdown(shutdown_signal())
    .await?;

    Ok(())
}

/// Build a `ProxyConfig` from `ARES_PROXY` and/or `ARES_PROXY_FILE` env vars.
///
/// - `ARES_PROXY_ROTATION` selects the rotation strategy (default `round-robin`).
/// - `ARES_PROXY` contributes a single proxy URL; a blank value is ignored.
/// - `ARES_PROXY_FILE` names a file with one proxy per line; blank lines and
///   lines starting with `#` are skipped.
///
/// Returns `Ok(None)` when neither source yields any proxy entry.
///
/// # Errors
///
/// Returns an error if the rotation strategy cannot be parsed or if the proxy
/// file cannot be read.
fn build_proxy_config() -> anyhow::Result<Option<ProxyConfig>> {
    // Parsed up front so a bad strategy is reported even when no proxies are set.
    let rotation: RotationStrategy = std::env::var("ARES_PROXY_ROTATION")
        .unwrap_or_else(|_| "round-robin".to_string())
        .parse()
        .map_err(|e: String| anyhow::anyhow!("{e}"))?;

    let mut entries: Vec<ProxyEntry> = Vec::new();

    if let Ok(url) = std::env::var("ARES_PROXY") {
        // `ARES_PROXY=` left blank in a `.env` file must not become an
        // empty-URL proxy entry.
        let url = url.trim();
        if !url.is_empty() {
            entries.push(ProxyEntry::new(url));
        }
    }

    if let Ok(path) = std::env::var("ARES_PROXY_FILE") {
        let content = std::fs::read_to_string(&path)
            .map_err(|e| anyhow::anyhow!("Failed to read ARES_PROXY_FILE '{path}': {e}"))?;
        entries.extend(
            content
                .lines()
                .map(str::trim)
                .filter(|line| !line.is_empty() && !line.starts_with('#'))
                .map(|line| ProxyEntry::new(line)),
        );
    }

    if entries.is_empty() {
        return Ok(None);
    }

    Ok(Some(ProxyConfig::new(entries, rotation)))
}

/// Parse an env var as a numeric type, falling back to a default.
///
/// Surrounding whitespace (including a trailing `\r`) is stripped before
/// parsing, so a padded value such as `" 30 "` is accepted rather than
/// silently replaced by the default.
///
/// Returns `default` when the variable is unset, is not valid Unicode, or does
/// not parse as `T`.
fn env_parse<T: std::str::FromStr>(var: &str, default: T) -> T {
    match std::env::var(var) {
        Ok(raw) => raw.trim().parse().unwrap_or(default),
        Err(_) => default,
    }
}

/// Resolves once the process is asked to stop, so the server can shut down
/// gracefully.
///
/// Waits for Ctrl+C on every platform and additionally for `SIGTERM` on Unix,
/// whichever arrives first.
///
/// # Panics
///
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install CTRL+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
            .expect("Failed to install SIGTERM handler")
            .recv()
            .await;
    };

    // No SIGTERM off Unix: a never-resolving future keeps the select uniform.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        () = ctrl_c => {},
        () = terminate => {},
    }

    tracing::info!("Shutdown signal received");
}