mod models;
mod handlers;
mod auth;
mod database;
mod error;
mod introspection;
mod setup;
mod config;
mod rate_limit;
mod connection_limit;
mod license;
mod observability;
use std::env;
use std::net::SocketAddr;

use anyhow::Context;
use axum::{
    middleware,
    routing::{delete, get, post},
    Router,
};
use tower_http::cors::CorsLayer;
use tracing::info;

use crate::config::load_server_config;
use crate::models::AppState;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let server_config = load_server_config().await?;
tracing_subscriber::fmt()
.with_env_filter(format!("pg_api={}", server_config.log_level))
.json()
.init();
let args: Vec<String> = env::args().collect();
if args.len() > 1 && args[1] == "setup" {
return setup::run_setup().await;
}
let state = AppState::new().await?;
let observability_config = observability::ObservabilityConfig::from_env();
let observability_client = std::sync::Arc::new(
observability::ObservabilityClient::new(observability_config.clone())
);
if observability_config.enabled {
observability_client.clone().start_flush_task();
info!("Observability enabled, sending metrics to: {:?}",
observability_config.opensearch_url);
} else {
info!("Observability disabled");
}
let app = Router::new()
.route("/", get(handlers::serve_docs))
.route("/docs", get(handlers::serve_docs))
.route("/openapi.json", get(handlers::serve_openapi))
.route("/health", get(handlers::health_check))
.route("/v1/status", get(handlers::status_handler))
.route("/v1/license", get(license::get_license_info))
.route("/v1/query", post(handlers::query_handler))
.route("/v1/batch", post(handlers::batch_query_handler))
.route("/v1/transaction", post(handlers::transaction_handler))
.route("/v1/databases", get(handlers::list_databases))
.route("/v1/databases", post(handlers::create_database))
.route("/v1/databases/{name}", delete(handlers::drop_database))
.route("/v1/databases/{db}/tables", get(handlers::list_tables))
.route("/v1/databases/{db}/schema", get(handlers::get_schema))
.route("/v1/account", get(handlers::get_account_info))
.route("/v1/account/usage", get(handlers::get_usage_stats))
.layer(middleware::from_fn_with_state(observability_client.clone(), observability::metrics_middleware))
.layer(middleware::from_fn_with_state(state.clone(), license::check_license_middleware))
.layer(middleware::from_fn_with_state(state.clone(), connection_limit::connection_limit_middleware))
.layer(middleware::from_fn_with_state(state.clone(), rate_limit::rate_limit_middleware))
.layer(middleware::from_fn_with_state(state.clone(), auth::auth_middleware))
.layer(middleware::from_fn(auth::request_id_middleware))
.layer(CorsLayer::permissive())
.with_state(state);
let addr = SocketAddr::new(server_config.host, server_config.port);
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("pg-api running on {}", addr);
info!("Documentation available at http://{}/docs", addr);
axum::serve(listener, app).await?;
Ok(())
}