use actix_cors::Cors;
use actix_web::body::{BoxBody, EitherBody};
use actix_web::dev::{Service, ServiceResponse};
use actix_web::http::header;
use actix_web::{App, HttpServer, web};
use clap::Parser;
use dotenv::dotenv;
use std::env;
use std::io::Error;
use std::io::ErrorKind::Other;
use std::io::Result as IoResult;
use std::net::{SocketAddr, TcpListener};
use std::path::PathBuf;
use std::thread::available_parallelism;
use std::time::Duration;
use tracing::info;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::time::ChronoLocal;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
use athena_rs::AppState;
use athena_rs::api::gateway::delete::delete_data;
use athena_rs::api::gateway::fetch::{
fetch_data_route, gateway_update_route, get_data_route, proxy_fetch_data_route,
};
use athena_rs::api::gateway::insert::insert_data;
use athena_rs::api::gateway::postgrest::{
postgrest_delete_route, postgrest_get_route, postgrest_patch_route, postgrest_post_route,
};
use athena_rs::api::gateway::query::gateway_query_route;
use athena_rs::api::health::{ping, root};
use athena_rs::api::metrics::prometheus_metrics_stub;
use athena_rs::api::pipelines::run_pipeline;
use athena_rs::api::query::sql::sql_query;
use athena_rs::api::registry::{api_registry, api_registry_by_id};
use athena_rs::api::supabase::ssl_enforcement;
use athena_rs::api::{admin, athena_docs, schema};
use athena_rs::api::{athena_openapi_host, athena_router_registry};
use athena_rs::bootstrap::{Bootstrap, build_shared_state};
use athena_rs::cli::{self, AthenaCli, Command};
use athena_rs::config::{Config, ConfigLocation, DEFAULT_CONFIG_FILE_NAME};
use athena_rs::parser::{parse_secs_or_default, parse_usize};
/// Process entry point.
///
/// Loads `.env`, initializes Sentry and tracing, parses CLI arguments,
/// resolves the configuration file, and dispatches to the selected
/// subcommand — or to the standalone API / CDC servers when no
/// subcommand is given.
#[actix_web::main]
async fn main() -> IoResult<()> {
    // Best-effort: a missing .env file is not an error.
    dotenv().ok();
    // Sentry must come up before tracing so init_tracing can attach the
    // Sentry forwarding layer when a client was created. The guard is kept
    // alive for the whole process so buffered events are flushed on exit.
    let sentry_guard: Option<sentry::ClientInitGuard> = init_sentry();
    init_tracing(sentry_guard.is_some());
    let _sentry_guard = sentry_guard;
    let cli = AthenaCli::parse();
    let config_path = cli.config_path.clone();
    let pipelines_path = cli.pipelines_path.clone();
    let port_override = cli.port;
    let command = cli.command;
    let api_only = cli.api_only;
    #[cfg(feature = "cdc")]
    let cdc_only = cli.cdc_only;
    // Whether the user supplied an explicit config file; this selects the
    // lookup strategy and the shape of the error report below.
    let config_overridden = config_path.is_some();
    let config_path = config_path
        .clone()
        .unwrap_or_else(|| PathBuf::from(DEFAULT_CONFIG_FILE_NAME));
    let pipelines_path: String = pipelines_path.to_string_lossy().to_string();
    let config: Config = if config_overridden {
        // Explicit path: only that single location is attempted, and the
        // failure message lists it as such.
        Config::load_from(&config_path).map_err(|err| {
            let attempted_locations = vec![ConfigLocation::new(
                "explicit config path".to_string(),
                config_path.clone(),
            )];
            Error::new(
                Other,
                format!(
                    "Failed to load config '{}': {}. Looked in:\n{}",
                    config_path.display(),
                    err,
                    format_attempted_locations(&attempted_locations)
                ),
            )
        })?
    } else {
        // Default lookup: searches standard locations and may seed a fresh
        // config file on first run (reported via `seeded_default`).
        match Config::load_default() {
            Ok(outcome) => {
                if outcome.seeded_default {
                    info!("Seeded default configuration at {}", outcome.path.display());
                }
                outcome.config
            }
            Err(err) => {
                // The loader reports every location it tried; surface them
                // all in the error message.
                let attempts = err.attempted_locations.clone();
                return Err(Error::new(
                    Other,
                    format!(
                        "Failed to load config '{}': {}. Looked in:\n{}",
                        config_path.display(),
                        err,
                        format_attempted_locations(&attempts)
                    ),
                ));
            }
        }
    };
    // Dispatch: every error from a subcommand is normalized into an
    // io::Error so main's IoResult signature stays uniform.
    match command {
        Some(Command::Server) => {
            let bootstrap = build_shared_state(&config, &pipelines_path)
                .await
                .map_err(|err| Error::new(Other, err.to_string()))?;
            run_server(&config, bootstrap, port_override).await
        }
        Some(Command::Pipeline(args)) => {
            let bootstrap = build_shared_state(&config, &pipelines_path)
                .await
                .map_err(|err| Error::new(Other, err.to_string()))?;
            cli::run_pipeline_command(&bootstrap, args)
                .await
                .map_err(|err| Error::new(Other, err.to_string()))
        }
        Some(Command::Clients { command }) => {
            cli::run_clients_command(command).map_err(|err| Error::new(Other, err.to_string()))
        }
        Some(Command::Fetch(cmd)) => cli::run_fetch_command(cmd)
            .await
            .map_err(|err| Error::new(Other, err.to_string())),
        #[cfg(feature = "cdc")]
        Some(Command::Cdc { command }) => cli::run_cdc_command(&config, command)
            .await
            .map_err(|err| Error::new(Other, err.to_string())),
        Some(Command::Diag) => {
            cli::run_diag_command().map_err(|err| Error::new(Other, err.to_string()))
        }
        Some(Command::Version) => {
            cli::run_version_command();
            Ok(())
        }
        None => {
            // No subcommand: behavior depends on build features and the
            // --api-only / --cdc-only flags.
            #[cfg(feature = "cdc")]
            {
                if cdc_only {
                    // CDC-only mode runs just the WebSocket server; 4053 is
                    // its default port when no override is given.
                    let port: u16 = port_override.unwrap_or(4053);
                    athena_rs::cdc::websocket::websocket_server(port)
                        .await
                        .map_err(|err| Error::new(Other, err.to_string()))
                } else if api_only {
                    let bootstrap: Bootstrap = build_shared_state(&config, &pipelines_path)
                        .await
                        .map_err(|err| Error::new(Other, err.to_string()))?;
                    run_server(&config, bootstrap, port_override).await
                } else {
                    Err(Error::new(
                        Other,
                        "No command provided; pass --api-only to boot the API, --cdc-only for the CDC WebSocket server, or specify a subcommand.",
                    ))
                }
            }
            #[cfg(not(feature = "cdc"))]
            {
                if api_only {
                    let bootstrap: Bootstrap = build_shared_state(&config, &pipelines_path)
                        .await
                        .map_err(|err| Error::new(Other, err.to_string()))?;
                    run_server(&config, bootstrap, port_override).await
                } else {
                    Err(Error::new(
                        Other,
                        "No command provided; pass --api-only to boot the API or specify a subcommand.",
                    ))
                }
            }
        }
    }
}
async fn run_server(
config: &Config,
bootstrap: Bootstrap,
port_override: Option<u16>,
) -> IoResult<()> {
let port: u16 = if let Some(port) = port_override {
port
} else {
config
.get_api()
.ok_or("No API port configured")
.and_then(|port_str| port_str.parse().map_err(|_| "Invalid port number"))
.expect("Failed to parse API port")
};
let keep_alive: Duration = parse_secs_or_default(config.get_http_keep_alive_secs(), 15);
let client_disconnect_timeout =
parse_secs_or_default(config.get_client_disconnect_timeout_secs(), 60);
let client_request_timeout =
parse_secs_or_default(config.get_client_request_timeout_secs(), 60);
let worker_count: usize = config
.get_http_workers()
.and_then(parse_usize)
.unwrap_or_else(|| available_parallelism().map(|n| n.get()).unwrap_or(4));
let max_connections: usize = config
.get_http_max_connections()
.and_then(parse_usize)
.unwrap_or(10_000);
let backlog: usize = config
.get_http_backlog()
.and_then(parse_usize)
.unwrap_or(2_048);
let tcp_keepalive: Duration = parse_secs_or_default(config.get_tcp_keepalive_secs(), 75);
let prometheus_metrics_enabled: bool = config.get_prometheus_metrics_enabled();
let addr: SocketAddr = SocketAddr::from(([0, 0, 0, 0], port));
let socket: Socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?;
socket.set_nonblocking(true)?;
socket.set_keepalive(true)?;
let keepalive_cfg: TcpKeepalive = TcpKeepalive::new().with_time(tcp_keepalive);
socket.set_tcp_keepalive(&keepalive_cfg)?;
socket.bind(&addr.into())?;
let listen_backlog: i32 = backlog.min(i32::MAX as usize) as i32;
socket.listen(listen_backlog)?;
let listener: TcpListener = socket.into();
let app_state: web::Data<AppState> = bootstrap.app_state.clone();
HttpServer::new(move || {
let cors: Cors = Cors::default()
.allow_any_origin()
.allow_any_method()
.allow_any_header();
let mut app = App::new()
.wrap(cors)
.wrap_fn(|req, srv| {
let fut = srv.call(req);
async move {
let mut res: ServiceResponse<EitherBody<BoxBody>> = fut.await?;
res.headers_mut()
.insert(header::SERVER, "XYLEX/0".parse().unwrap());
Ok(res)
}
})
.app_data(app_state.clone())
.service(root)
.service(ping)
.service(sql_query)
.service(fetch_data_route)
.service(get_data_route)
.service(proxy_fetch_data_route)
.service(gateway_update_route)
.service(gateway_query_route)
.service(insert_data)
.service(delete_data)
.service(postgrest_get_route)
.service(postgrest_post_route)
.service(postgrest_patch_route)
.service(postgrest_delete_route)
.service(run_pipeline)
.service(athena_router_registry)
.service(athena_openapi_host)
.service(api_registry)
.service(athena_docs)
.service(api_registry_by_id)
.configure(admin::services)
.configure(schema::services)
.service(ssl_enforcement);
if prometheus_metrics_enabled {
app = app.service(prometheus_metrics_stub);
}
app
})
.workers(worker_count)
.keep_alive(keep_alive)
.client_disconnect_timeout(client_disconnect_timeout)
.client_request_timeout(client_request_timeout)
.max_connections(max_connections)
.backlog(backlog as u32)
.listen(listener)?
.run()
.await
}
fn init_tracing(enable_sentry_layer: bool) {
let filter: EnvFilter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
let fmt_layer =
tracing_subscriber::fmt::layer().with_timer(ChronoLocal::new("%H:%M:%S%.3f".to_string()));
let base = tracing_subscriber::registry().with(filter).with(fmt_layer);
if enable_sentry_layer {
base.with(sentry_tracing::layer()).init();
} else {
base.init();
}
}
/// Fallback DSN used when `SENTRY_DSN` is not set in the environment.
const BETTERSTACK_SENTRY_DSN: &str =
    "https://mMR1Bs5K6vSzXT8YGYyUxQSE@s1741777.eu-fsn-3.betterstackdata.com/1741777";

/// Initializes the Sentry client and returns its flush guard.
///
/// Reads `SENTRY_DSN` (falling back to the built-in DSN), plus the optional
/// `SENTRY_ENVIRONMENT` and `SENTRY_SAMPLE_RATE` variables.
///
/// Returns `None` when the DSN is blank (treated as an explicit opt-out) or
/// cannot be parsed. Errors are written to stderr because tracing is not yet
/// initialized when this runs.
fn init_sentry() -> Option<sentry::ClientInitGuard> {
    let dsn_source = env::var("SENTRY_DSN").unwrap_or_else(|_| BETTERSTACK_SENTRY_DSN.to_string());
    // An explicitly empty/blank SENTRY_DSN disables Sentry entirely.
    if dsn_source.trim().is_empty() {
        return None;
    }
    let dsn = match dsn_source.parse::<sentry::types::Dsn>() {
        Ok(dsn) => dsn,
        Err(err) => {
            // Tracing is not up yet, so report on stderr and run without Sentry.
            eprintln!("failed to parse Sentry DSN: {err}");
            return None;
        }
    };
    let mut options = sentry::ClientOptions {
        dsn: Some(dsn),
        release: Some(env!("CARGO_PKG_VERSION").into()),
        environment: env::var("SENTRY_ENVIRONMENT").ok().map(Into::into),
        attach_stacktrace: true,
        ..Default::default()
    };
    if let Ok(value) = env::var("SENTRY_SAMPLE_RATE") {
        match value.parse::<f32>() {
            // Sentry requires a sample rate within [0.0, 1.0]; clamp
            // out-of-range finite values instead of handing the SDK an
            // invalid rate, and reject NaN/infinite values outright.
            Ok(parsed) if parsed.is_finite() => {
                options.sample_rate = parsed.clamp(0.0, 1.0);
            }
            _ => eprintln!("ignoring invalid SENTRY_SAMPLE_RATE: '{value}'"),
        }
    }
    Some(sentry::init(options))
}
/// Renders the attempted config locations as a newline-separated bulleted
/// list ("- <description>") for inclusion in error messages.
fn format_attempted_locations(locations: &[ConfigLocation]) -> String {
    let mut bullets: Vec<String> = Vec::with_capacity(locations.len());
    for location in locations {
        bullets.push(format!("- {}", location.describe()));
    }
    bullets.join("\n")
}