use std::net::SocketAddr;
use std::sync::Arc;
use axum::http::HeaderValue;
use deadpool_postgres::{Config, Runtime};
use tokio_postgres::NoTls;
use tower_governor::GovernorLayer;
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::cors::{AllowOrigin, CorsLayer};
pub mod arrow_encode;
pub mod common;
pub mod datalog;
pub mod metrics;
pub mod routing;
pub mod spi_bridge;
pub mod stream;
use common::{AppState, env_or};
/// Lowest `pg_ripple` extension version this build treats as compatible.
/// The start-up compatibility check compares the installed `extversion`
/// against this value.
const COMPATIBLE_EXTENSION_MIN: &str = "0.115.0";
/// Compares the installed `pg_ripple` extension version with
/// `COMPATIBLE_EXTENSION_MIN` and reports the outcome through `tracing`.
///
/// Environment switches (each enabled by `1` or a case-insensitive `true`):
/// - `PG_RIPPLE_HTTP_SKIP_COMPAT_CHECK`: return immediately without querying.
/// - `PG_RIPPLE_HTTP_STRICT_COMPAT`: exit the process with status 1 when the
///   installed version is below the minimum, instead of only warning.
///
/// A missing `pg_extension` row or a failed catalog query is logged as a
/// warning and ends the check; neither is fatal here, even in strict mode.
async fn check_extension_compatibility(client: &deadpool_postgres::Object) {
    // Both switches share the same truthiness rule.
    let flag_enabled = |name: &str| match std::env::var(name) {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    };

    if flag_enabled("PG_RIPPLE_HTTP_SKIP_COMPAT_CHECK") {
        tracing::debug!(
            "PG_RIPPLE_HTTP_SKIP_COMPAT_CHECK=1: skipping extension compatibility check"
        );
        return;
    }
    let strict = flag_enabled("PG_RIPPLE_HTTP_STRICT_COMPAT");

    let lookup = client
        .query_opt(
            "SELECT extversion FROM pg_extension WHERE extname = 'pg_ripple'",
            &[],
        )
        .await;
    let ext_version = match lookup {
        Err(e) => {
            tracing::warn!("could not query pg_ripple extension version: {e}");
            return;
        }
        Ok(None) => {
            tracing::warn!(
                "pg_ripple extension not found in pg_extension catalog — \
                 compatibility check skipped"
            );
            return;
        }
        Ok(Some(row)) => row.get::<_, String>(0),
    };

    tracing::info!(
        ext_version = %ext_version,
        min_supported = %COMPATIBLE_EXTENSION_MIN,
        "pg_ripple extension compatibility check"
    );

    // Nothing more to report when the version is acceptable (or unparsable,
    // which `semver_lt` treats as "not lower").
    if !semver_lt(&ext_version, COMPATIBLE_EXTENSION_MIN) {
        return;
    }

    if strict {
        tracing::error!(
            ext_version = %ext_version,
            min_supported = %COMPATIBLE_EXTENSION_MIN,
            "PG_RIPPLE_HTTP_STRICT_COMPAT=1: extension version is below minimum — aborting"
        );
        std::process::exit(1);
    }

    tracing::warn!(
        ext_version = %ext_version,
        min_supported = %COMPATIBLE_EXTENSION_MIN,
        "pg_ripple extension version is below the minimum supported by this pg_ripple_http \
         build — some features may not work correctly. \
         Upgrade the extension with: ALTER EXTENSION pg_ripple UPDATE; \
         or set PG_RIPPLE_HTTP_SKIP_COMPAT_CHECK=1 to suppress this warning. \
         Set PG_RIPPLE_HTTP_STRICT_COMPAT=1 to make this a fatal startup error."
    );
}
/// Returns `true` only when both strings parse as `major.minor.patch` and
/// `version` orders strictly before `min`.
///
/// If either string fails to parse the result is `false`, i.e. an
/// unparsable version is never reported as "too old".
fn semver_lt(version: &str, min: &str) -> bool {
    match (parse_semver(version), parse_semver(min)) {
        (Some(have), Some(need)) => have < need,
        _ => false,
    }
}
/// Parses the numeric `major.minor.patch` core of a version string.
///
/// Surrounding whitespace is ignored. Anything following the patch number
/// that starts with `-` (pre-release) or `+` (build metadata) is discarded,
/// so `"1.2.3-beta.1"` and `"1.2.3+build.7"` both yield `(1, 2, 3)`.
///
/// Returns `None` when fewer than three dot-separated components are present
/// or when any of the three is not a valid `u32`.
fn parse_semver(s: &str) -> Option<(u32, u32, u32)> {
    // `splitn(3, ..)` keeps every suffix (including further dots) in the third piece.
    let mut parts = s.trim().splitn(3, '.');
    let major = parts.next()?.parse::<u32>().ok()?;
    let minor = parts.next()?.parse::<u32>().ok()?;
    // Cut the patch component at the first pre-release / build-metadata marker.
    let patch = parts
        .next()?
        .split(['-', '+'])
        .next()?
        .parse::<u32>()
        .ok()?;
    Some((major, minor, patch))
}
#[tokio::main]
async fn main() {
let log_format = std::env::var("RUST_LOG_FORMAT").unwrap_or_default();
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
"pg_ripple_http=info".parse().unwrap_or_else(|e| {
eprintln!("error parsing log filter: {e}");
std::process::exit(1);
})
});
if log_format.eq_ignore_ascii_case("json") {
tracing_subscriber::fmt()
.json()
.with_env_filter(env_filter)
.init();
} else {
tracing_subscriber::fmt().with_env_filter(env_filter).init();
}
let pg_url = {
let args: Vec<String> = std::env::args().collect();
if args.len() > 1 {
args[1].clone()
} else {
env_or("PG_RIPPLE_HTTP_PG_URL", "postgresql://localhost/postgres")
}
};
let port: u16 = match env_or("PG_RIPPLE_HTTP_PORT", "7878").parse() {
Ok(p) => p,
Err(e) => {
tracing::error!("PG_RIPPLE_HTTP_PORT must be a valid port number: {e}");
std::process::exit(1);
}
};
let pool_size: usize = match env_or("PG_RIPPLE_HTTP_POOL_SIZE", "16").parse() {
Ok(n) => n,
Err(e) => {
tracing::error!("PG_RIPPLE_HTTP_POOL_SIZE must be a positive integer: {e}");
std::process::exit(1);
}
};
let auth_token = std::env::var("PG_RIPPLE_HTTP_AUTH_TOKEN").ok();
let datalog_write_token = std::env::var("PG_RIPPLE_HTTP_DATALOG_WRITE_TOKEN").ok();
let rate_limit: u32 = match env_or("PG_RIPPLE_HTTP_RATE_LIMIT", "100").parse() {
Ok(r) => r,
Err(e) => {
tracing::error!("PG_RIPPLE_HTTP_RATE_LIMIT must be a non-negative integer: {e}");
std::process::exit(1);
}
};
let cors_origins = env_or("PG_RIPPLE_HTTP_CORS_ORIGINS", "");
let max_body_bytes: usize = match env_or("PG_RIPPLE_HTTP_MAX_BODY_BYTES", "10485760").parse() {
Ok(n) => n,
Err(e) => {
tracing::error!("PG_RIPPLE_HTTP_MAX_BODY_BYTES must be a positive integer: {e}");
std::process::exit(1);
}
};
let trust_proxy = std::env::var("PG_RIPPLE_HTTP_TRUST_PROXY").ok();
if let Ok(ca_path) = std::env::var("PG_RIPPLE_HTTP_CA_BUNDLE") {
match std::fs::read_to_string(&ca_path) {
Ok(pem) if !pem.trim().is_empty() && pem.contains("BEGIN CERTIFICATE") => {
tracing::info!("PG_RIPPLE_HTTP_CA_BUNDLE: loaded CA bundle from {ca_path}");
unsafe { std::env::set_var("PG_RIPPLE_HTTP_CA_PEM", pem) };
}
Ok(_) => {
tracing::error!(
"PG_RIPPLE_HTTP_CA_BUNDLE: file at '{ca_path}' is not a valid PEM bundle \
(no 'BEGIN CERTIFICATE' marker) — falling back to system trust store"
);
}
Err(e) => {
tracing::error!(
"PG_RIPPLE_HTTP_CA_BUNDLE: cannot read '{ca_path}': {e} \
— falling back to system trust store"
);
}
}
}
if let Ok(fps) = std::env::var("PG_RIPPLE_HTTP_PIN_FINGERPRINTS") {
let count = fps.split(',').filter(|s| !s.trim().is_empty()).count();
if count == 0 {
tracing::warn!(
"PG_RIPPLE_HTTP_PIN_FINGERPRINTS is set but contains no valid fingerprints \
— pinning is disabled"
);
} else {
tracing::info!(
"PG_RIPPLE_HTTP_PIN_FINGERPRINTS: {count} pinned certificate fingerprint(s) loaded"
);
}
}
let mut cfg = Config::new();
cfg.url = Some(pg_url.clone());
cfg.pool = Some(deadpool_postgres::PoolConfig::new(pool_size));
let pool = match cfg.create_pool(Some(Runtime::Tokio1), NoTls) {
Ok(p) => p,
Err(e) => {
tracing::error!("failed to create PostgreSQL connection pool: {e}");
std::process::exit(1);
}
};
{
let client = match pool.get().await {
Ok(c) => c,
Err(e) => {
tracing::error!(
"failed to connect to PostgreSQL — check PG_RIPPLE_HTTP_PG_URL: {e}"
);
std::process::exit(1);
}
};
let row = match client
.query_one("SELECT pg_ripple.triple_count()", &[])
.await
{
Ok(r) => r,
Err(e) => {
tracing::error!("pg_ripple extension not available — is it installed? ({e})");
std::process::exit(1);
}
};
let count: i64 = row.get(0);
tracing::info!(
"connected to {pg_url} (port {port}), triple store contains {count} triples"
);
check_extension_compatibility(&client).await;
}
let cors_is_permissive = cors_origins == "*";
let metrics_token = std::env::var("PG_RIPPLE_HTTP_METRICS_TOKEN").ok();
if metrics_token.is_some() {
tracing::info!(
"PG_RIPPLE_HTTP_METRICS_TOKEN set: GET /metrics requires Authorization: Bearer <token>"
);
}
let state = Arc::new(AppState {
pool,
auth_token,
datalog_write_token,
trust_proxy,
metrics: metrics::Metrics::new(),
ever_connected: std::sync::atomic::AtomicBool::new(false),
arrow_flight_secret: std::env::var("ARROW_FLIGHT_SECRET").ok(),
arrow_unsigned_tickets_allowed: std::env::var("ARROW_UNSIGNED_TICKETS_ALLOWED")
.map(|v| v.eq_ignore_ascii_case("true") || v == "1")
.unwrap_or(false),
arrow_nonce_cache: dashmap::DashMap::new(),
arrow_nonce_cache_max: std::env::var("ARROW_NONCE_CACHE_MAX")
.ok()
.and_then(|v| v.parse::<usize>().ok())
.unwrap_or(10_000),
cors_is_permissive,
metrics_token,
});
let cors = if cors_is_permissive {
tracing::warn!(
"CORS is permissive (*). Set PG_RIPPLE_HTTP_CORS_ORIGINS to a comma-separated list of allowed origins for production use. \
Monitor pg_ripple_http_cors_permissive_requests_total for cross-origin traffic."
);
CorsLayer::permissive()
} else if cors_origins.is_empty() {
CorsLayer::new()
} else {
let origins: Vec<HeaderValue> = cors_origins
.split(',')
.filter_map(|o| o.trim().parse().ok())
.collect();
CorsLayer::new().allow_origin(AllowOrigin::list(origins))
};
let mut app = routing::build_router(state.clone(), max_body_bytes, cors);
if rate_limit > 0 {
let governor_conf = match GovernorConfigBuilder::default()
.per_second(rate_limit as u64)
.burst_size(rate_limit)
.finish()
{
Some(c) => c,
None => {
tracing::error!("invalid governor rate-limit configuration");
std::process::exit(1);
}
};
app = app.layer(GovernorLayer::new(Arc::new(governor_conf)));
}
let addr = SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("pg_ripple_http listening on http://{addr}");
let listener = match tokio::net::TcpListener::bind(addr).await {
Ok(l) => l,
Err(e) => {
tracing::error!("failed to bind TCP listener on {addr}: {e}");
std::process::exit(1);
}
};
if let Err(e) = axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal())
.await
{
tracing::error!("server error: {e}");
std::process::exit(1);
}
}
/// Future passed to axum's `with_graceful_shutdown`.
///
/// Waits for Ctrl+C (or, on Unix, SIGTERM), logs which one arrived, then
/// sleeps for `PG_RIPPLE_HTTP_SHUTDOWN_TIMEOUT_SECS` seconds before returning.
/// The delay defaults to 30; an unset or unparsable value also yields 30, and
/// `0` skips the sleep entirely.
///
/// NOTE(review): the sleep runs *before* this future resolves, so the caller
/// is only notified after the delay has elapsed. Confirm that a pre-shutdown
/// delay is the intended meaning of the "drain" timeout.
///
/// # Panics
///
/// Panics if the Ctrl+C or SIGTERM handler cannot be installed.
async fn shutdown_signal() {
    use std::time::Duration;
    use tokio::signal;

    let shutdown_timeout_secs = match std::env::var("PG_RIPPLE_HTTP_SHUTDOWN_TIMEOUT_SECS") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(30),
        Err(_) => 30,
    };

    // On non-Unix targets there is no SIGTERM, so that arm never completes.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();
    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install SIGTERM handler");
        stream.recv().await;
    };

    let interrupt = async {
        let outcome = signal::ctrl_c().await;
        outcome.expect("failed to install Ctrl+C handler");
    };

    tokio::select! {
        _ = interrupt => {
            tracing::info!(
                "received Ctrl+C, initiating graceful shutdown ({shutdown_timeout_secs}s drain)"
            );
        }
        _ = sigterm => {
            tracing::info!(
                "received SIGTERM, initiating graceful shutdown ({shutdown_timeout_secs}s drain)"
            );
        }
    }

    if shutdown_timeout_secs == 0 {
        return;
    }
    tokio::time::sleep(Duration::from_secs(shutdown_timeout_secs)).await;
}