// athena_rs 0.83.0
//
// Database gateway API
// Documentation
//! Athena RS binary.
//!
//! Starts the Actix Web server, wires endpoints, configures CORS and tracing,
//! and exposes convenience endpoints for Scylla demo queries.
//!
//!
use actix_cors::Cors;
use actix_web::body::{BoxBody, EitherBody};
use actix_web::dev::{Service, ServiceResponse};
use actix_web::http::header;
use actix_web::{App, HttpServer, web};
use clap::Parser;
use dotenv::dotenv;
use std::env;
use std::io::Error;
use std::io::ErrorKind::Other;
use std::io::Result as IoResult;
use std::net::{SocketAddr, TcpListener};
use std::path::PathBuf;
use std::thread::available_parallelism;
use std::time::Duration;
use tracing::info;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::time::ChronoLocal;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;

use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};

use athena_rs::AppState;
use athena_rs::api::gateway::delete::delete_data;
use athena_rs::api::gateway::fetch::{
    fetch_data_route, gateway_update_route, get_data_route, proxy_fetch_data_route,
};
use athena_rs::api::gateway::insert::insert_data;
use athena_rs::api::gateway::postgrest::{
    postgrest_delete_route, postgrest_get_route, postgrest_patch_route, postgrest_post_route,
};
use athena_rs::api::gateway::query::gateway_query_route;
use athena_rs::api::health::{ping, root};
use athena_rs::api::metrics::prometheus_metrics_stub;
use athena_rs::api::pipelines::run_pipeline;
use athena_rs::api::query::sql::sql_query;
use athena_rs::api::registry::{api_registry, api_registry_by_id};
use athena_rs::api::supabase::ssl_enforcement;
use athena_rs::api::{admin, athena_docs, schema};
use athena_rs::api::{athena_openapi_host, athena_router_registry};
use athena_rs::bootstrap::{Bootstrap, build_shared_state};
use athena_rs::cli::{self, AthenaCli, Command};
use athena_rs::config::{Config, ConfigLocation, DEFAULT_CONFIG_FILE_NAME};
use athena_rs::parser::{parse_secs_or_default, parse_usize};

/// Entry point of the Athena RS binary.
///
/// Loads `.env`, initialises Sentry and tracing, parses the command line, loads the
/// configuration (from the explicit path when one was passed, otherwise via
/// `Config::load_default`) and dispatches to the selected subcommand. Without a
/// subcommand, `--api-only` boots the HTTP API and, when built with the `cdc`
/// feature, `--cdc-only` starts the CDC WebSocket server.
///
/// # Errors
///
/// Configuration and subcommand failures are converted into an `io::Error` of kind
/// `Other` carrying the underlying error's message; errors from `run_server` are
/// returned unchanged.
#[actix_web::main]
async fn main() -> IoResult<()> {
    /// Wraps any error that can be rendered as text into an `io::Error` of kind `Other`.
    fn to_io_error<E: ToString>(err: E) -> Error {
        Error::new(Other, err.to_string())
    }

    dotenv().ok();
    let sentry_guard: Option<sentry::ClientInitGuard> = init_sentry();
    init_tracing(sentry_guard.is_some());
    // Keep the guard bound for the rest of `main` so it is only dropped on exit.
    let _sentry_guard = sentry_guard;

    let cli = AthenaCli::parse();
    let pipelines_path: String = cli.pipelines_path.to_string_lossy().to_string();
    let port_override = cli.port;
    let api_only = cli.api_only;
    #[cfg(feature = "cdc")]
    let cdc_only = cli.cdc_only;
    let config_path = cli.config_path;
    let command = cli.command;

    // Remember whether the user chose the path before substituting the default name,
    // which is only used for error reporting in the default-lookup branch.
    let config_overridden = config_path.is_some();
    let config_path = config_path.unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_FILE_NAME));

    let config: Config = if config_overridden {
        Config::load_from(&config_path).map_err(|err| {
            let attempted_locations = [ConfigLocation::new(
                "explicit config path".to_string(),
                config_path.clone(),
            )];
            Error::new(
                Other,
                format!(
                    "Failed to load config '{}': {}. Looked in:\n{}",
                    config_path.display(),
                    err,
                    format_attempted_locations(&attempted_locations)
                ),
            )
        })?
    } else {
        let outcome = Config::load_default().map_err(|err| {
            Error::new(
                Other,
                format!(
                    "Failed to load config '{}': {}. Looked in:\n{}",
                    config_path.display(),
                    err,
                    format_attempted_locations(&err.attempted_locations)
                ),
            )
        })?;
        if outcome.seeded_default {
            info!("Seeded default configuration at {}", outcome.path.display());
        }
        outcome.config
    };

    match command {
        Some(Command::Server) => {
            let bootstrap = build_shared_state(&config, &pipelines_path)
                .await
                .map_err(to_io_error)?;
            run_server(&config, bootstrap, port_override).await
        }
        Some(Command::Pipeline(args)) => {
            let bootstrap = build_shared_state(&config, &pipelines_path)
                .await
                .map_err(to_io_error)?;
            cli::run_pipeline_command(&bootstrap, args)
                .await
                .map_err(to_io_error)
        }
        Some(Command::Clients { command }) => {
            cli::run_clients_command(command).map_err(to_io_error)
        }
        Some(Command::Fetch(cmd)) => cli::run_fetch_command(cmd).await.map_err(to_io_error),
        #[cfg(feature = "cdc")]
        Some(Command::Cdc { command }) => cli::run_cdc_command(&config, command)
            .await
            .map_err(to_io_error),
        Some(Command::Diag) => cli::run_diag_command().map_err(to_io_error),
        Some(Command::Version) => {
            cli::run_version_command();
            Ok(())
        }
        None => {
            // No subcommand: fall back to the mode flags.
            #[cfg(feature = "cdc")]
            {
                if cdc_only {
                    let port: u16 = port_override.unwrap_or(4053);
                    athena_rs::cdc::websocket::websocket_server(port)
                        .await
                        .map_err(to_io_error)
                } else if api_only {
                    let bootstrap: Bootstrap = build_shared_state(&config, &pipelines_path)
                        .await
                        .map_err(to_io_error)?;
                    run_server(&config, bootstrap, port_override).await
                } else {
                    Err(Error::new(
                        Other,
                        "No command provided; pass --api-only to boot the API, --cdc-only for the CDC WebSocket server, or specify a subcommand.",
                    ))
                }
            }
            #[cfg(not(feature = "cdc"))]
            {
                if api_only {
                    let bootstrap: Bootstrap = build_shared_state(&config, &pipelines_path)
                        .await
                        .map_err(to_io_error)?;
                    run_server(&config, bootstrap, port_override).await
                } else {
                    Err(Error::new(
                        Other,
                        "No command provided; pass --api-only to boot the API or specify a subcommand.",
                    ))
                }
            }
        }
    }
}

async fn run_server(
    config: &Config,
    bootstrap: Bootstrap,
    port_override: Option<u16>,
) -> IoResult<()> {
    let port: u16 = if let Some(port) = port_override {
        port
    } else {
        config
            .get_api()
            .ok_or("No API port configured")
            .and_then(|port_str| port_str.parse().map_err(|_| "Invalid port number"))
            .expect("Failed to parse API port")
    };

    let keep_alive: Duration = parse_secs_or_default(config.get_http_keep_alive_secs(), 15);
    let client_disconnect_timeout =
        parse_secs_or_default(config.get_client_disconnect_timeout_secs(), 60);
    let client_request_timeout =
        parse_secs_or_default(config.get_client_request_timeout_secs(), 60);
    let worker_count: usize = config
        .get_http_workers()
        .and_then(parse_usize)
        .unwrap_or_else(|| available_parallelism().map(|n| n.get()).unwrap_or(4));
    let max_connections: usize = config
        .get_http_max_connections()
        .and_then(parse_usize)
        .unwrap_or(10_000);
    let backlog: usize = config
        .get_http_backlog()
        .and_then(parse_usize)
        .unwrap_or(2_048);
    let tcp_keepalive: Duration = parse_secs_or_default(config.get_tcp_keepalive_secs(), 75);
    let prometheus_metrics_enabled: bool = config.get_prometheus_metrics_enabled();

    let addr: SocketAddr = SocketAddr::from(([0, 0, 0, 0], port));
    let socket: Socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
    socket.set_nonblocking(true)?;
    socket.set_keepalive(true)?;
    let keepalive_cfg: TcpKeepalive = TcpKeepalive::new().with_time(tcp_keepalive);
    socket.set_tcp_keepalive(&keepalive_cfg)?;
    socket.bind(&addr.into())?;
    let listen_backlog: i32 = backlog.min(i32::MAX as usize) as i32;
    socket.listen(listen_backlog)?;
    let listener: TcpListener = socket.into();

    let app_state: web::Data<AppState> = bootstrap.app_state.clone();

    HttpServer::new(move || {
        let cors: Cors = Cors::default()
            .allow_any_origin()
            .allow_any_method()
            .allow_any_header();
        let mut app = App::new()
            .wrap(cors)
            .wrap_fn(|req, srv| {
                let fut = srv.call(req);
                async move {
                    let mut res: ServiceResponse<EitherBody<BoxBody>> = fut.await?;
                    res.headers_mut()
                        .insert(header::SERVER, "XYLEX/0".parse().unwrap());
                    Ok(res)
                }
            })
            .app_data(app_state.clone())
            .service(root)
            .service(ping)
            .service(sql_query)
            .service(fetch_data_route)
            .service(get_data_route)
            .service(proxy_fetch_data_route)
            .service(gateway_update_route)
            .service(gateway_query_route)
            .service(insert_data)
            .service(delete_data)
            .service(postgrest_get_route)
            .service(postgrest_post_route)
            .service(postgrest_patch_route)
            .service(postgrest_delete_route)
            .service(run_pipeline)
            .service(athena_router_registry)
            .service(athena_openapi_host)
            .service(api_registry)
            .service(athena_docs)
            .service(api_registry_by_id)
            .configure(admin::services)
            .configure(schema::services)
            .service(ssl_enforcement);
        if prometheus_metrics_enabled {
            app = app.service(prometheus_metrics_stub);
        }
        app
    })
    .workers(worker_count)
    .keep_alive(keep_alive)
    .client_disconnect_timeout(client_disconnect_timeout)
    .client_request_timeout(client_request_timeout)
    .max_connections(max_connections)
    .backlog(backlog as u32)
    .listen(listener)?
    .run()
    .await
}

/// Installs the global tracing subscriber.
///
/// Log output carries a local `HH:MM:SS.mmm` timestamp and is filtered by the default
/// environment filter, falling back to `info` when that cannot be built. When
/// `enable_sentry_layer` is true the Sentry tracing layer is stacked on top.
fn init_tracing(enable_sentry_layer: bool) {
    let env_filter: EnvFilter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };
    let timestamps = ChronoLocal::new("%H:%M:%S%.3f".to_string());
    let console_layer = tracing_subscriber::fmt::layer().with_timer(timestamps);

    let subscriber = tracing_subscriber::registry()
        .with(env_filter)
        .with(console_layer);

    if !enable_sentry_layer {
        subscriber.init();
        return;
    }
    subscriber.with(sentry_tracing::layer()).init();
}

/// Fallback Sentry DSN used by `init_sentry` when the `SENTRY_DSN` environment
/// variable is not set. Setting `SENTRY_DSN` to an empty string disables Sentry.
// NOTE(review): this DSN (including its key) is committed to source, so builds report
// to this endpoint by default — confirm that is intended.
const BETTERSTACK_SENTRY_DSN: &str =
    "https://mMR1Bs5K6vSzXT8YGYyUxQSE@s1741777.eu-fsn-3.betterstackdata.com/1741777";

/// Initialises the Sentry client from the environment.
///
/// * `SENTRY_DSN` — DSN to report to; falls back to `BETTERSTACK_SENTRY_DSN` when unset.
///   A blank value disables Sentry.
/// * `SENTRY_ENVIRONMENT` — optional environment name.
/// * `SENTRY_SAMPLE_RATE` — optional event sample rate, clamped to `0.0..=1.0`.
///
/// Returns the client guard, or `None` when Sentry is disabled or the DSN cannot be
/// parsed. Problems are reported on stderr because tracing is not initialised yet
/// when this runs.
fn init_sentry() -> Option<sentry::ClientInitGuard> {
    let dsn_source = env::var("SENTRY_DSN").unwrap_or_else(|_| BETTERSTACK_SENTRY_DSN.to_string());

    if dsn_source.trim().is_empty() {
        return None;
    }

    let dsn = match dsn_source.parse::<sentry::types::Dsn>() {
        Ok(dsn) => dsn,
        Err(err) => {
            eprintln!("failed to parse Sentry DSN: {err}");
            return None;
        }
    };

    let mut options = sentry::ClientOptions {
        dsn: Some(dsn),
        release: Some(env!("CARGO_PKG_VERSION").into()),
        environment: env::var("SENTRY_ENVIRONMENT").ok().map(Into::into),
        attach_stacktrace: true,
        ..Default::default()
    };

    if let Ok(raw) = env::var("SENTRY_SAMPLE_RATE") {
        match raw.trim().parse::<f32>() {
            // NaN parses successfully but is not a usable rate, so it is rejected too.
            Ok(rate) if !rate.is_nan() => options.sample_rate = rate.clamp(0.0, 1.0),
            // Keep the default rate, but tell the operator the setting was not applied.
            _ => eprintln!("ignoring invalid SENTRY_SAMPLE_RATE value: {raw:?}"),
        }
    }

    Some(sentry::init(options))
}

/// Renders the attempted config locations as a bullet list, one `- <description>`
/// entry per line, with no trailing newline. An empty slice yields an empty string.
fn format_attempted_locations(locations: &[ConfigLocation]) -> String {
    let mut rendered = String::new();
    for (index, location) in locations.iter().enumerate() {
        // Separator goes between entries only, never after the last one.
        if index > 0 {
            rendered.push('\n');
        }
        rendered.push_str(&format!("- {}", location.describe()));
    }
    rendered
}