use std::sync::Arc;
use axum::{
Router, middleware,
routing::{get, post},
};
#[cfg(feature = "arrow")]
use fraiseql_arrow::FraiseQLFlightService;
use fraiseql_core::{
db::traits::DatabaseAdapter,
runtime::{Executor, SubscriptionManager},
schema::CompiledSchema,
security::OidcValidator,
};
use tokio::net::TcpListener;
#[cfg(feature = "observers")]
use tracing::error;
use tracing::{info, warn};
#[cfg(feature = "observers")]
use {
crate::observers::{ObserverRuntime, ObserverRuntimeConfig},
tokio::sync::RwLock,
};
use crate::{
Result, ServerError,
middleware::{
BearerAuthState, OidcAuthState, RateLimiter, bearer_auth_middleware, cors_layer_restricted,
metrics_middleware, oidc_auth_middleware, trace_layer,
},
routes::{
PlaygroundState, SubscriptionState, api, graphql::AppState, graphql_get_handler,
graphql_handler, health_handler, introspection_handler, metrics_handler,
metrics_json_handler, playground_handler, subscription_handler,
},
server_config::ServerConfig,
tls::TlsSetup,
};
/// FraiseQL HTTP server: serves a compiled schema over GraphQL plus the
/// optional playground, subscription, introspection/schema-export, metrics,
/// admin, design-audit and observer endpoints selected by its [`ServerConfig`],
/// and (feature `arrow`) an Arrow Flight gRPC service.
pub struct Server<A: DatabaseAdapter> {
    /// Paths, feature flags, tokens, CORS and TLS settings.
    config: ServerConfig,
    /// Query executor shared with the request handlers via `AppState`.
    executor: Arc<Executor<A>>,
    /// Subscription manager backing the subscription endpoint.
    subscription_manager: Arc<SubscriptionManager>,
    /// OIDC validator; `Some` when `config.auth` is configured.
    oidc_validator: Option<Arc<OidcValidator>>,
    /// Per-IP rate limiter; `Some` when `config.rate_limiting` is present and enabled.
    rate_limiter: Option<Arc<RateLimiter>>,
    /// Observer runtime; `Some` when observers are enabled in config and a pool was given.
    #[cfg(feature = "observers")]
    observer_runtime: Option<Arc<RwLock<ObserverRuntime>>>,
    /// PostgreSQL pool handed to the observer runtime and observer repository.
    #[cfg(feature = "observers")]
    db_pool: Option<sqlx::PgPool>,
    /// Arrow Flight service started alongside the HTTP server when `Some`.
    #[cfg(feature = "arrow")]
    flight_service: Option<FraiseQLFlightService>,
}
impl<A: DatabaseAdapter + Clone + Send + Sync + 'static> Server<A> {
    /// Create a server for `schema`, executing queries through `adapter`.
    ///
    /// Besides the executor and subscription manager, this sets up the optional
    /// components selected by `config`:
    ///
    /// - OIDC authentication when `config.auth` is set;
    /// - per-IP rate limiting when `config.rate_limiting` is set and enabled;
    /// - (feature `observers`) the observer runtime, which also needs `db_pool`;
    /// - (feature `arrow`) a default Arrow Flight service that shares the OIDC
    ///   validator, if one was created.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::ConfigError`] if the OIDC validator fails to
    /// initialize.
    pub async fn new(
        config: ServerConfig,
        schema: CompiledSchema,
        adapter: Arc<A>,
        // Only read when the `observers` feature is enabled.
        #[allow(unused_variables)] db_pool: Option<sqlx::PgPool>,
    ) -> Result<Self> {
        let executor = Arc::new(Executor::new(schema.clone(), adapter));
        let subscription_manager = Arc::new(SubscriptionManager::new(Arc::new(schema)));
        let oidc_validator = Self::init_oidc_validator(&config).await?;
        let rate_limiter = Self::init_rate_limiter(&config);
        #[cfg(feature = "observers")]
        let observer_runtime = Self::init_observer_runtime(&config, db_pool.as_ref()).await;
        #[cfg(feature = "arrow")]
        let flight_service = Some(Self::default_flight_service(oidc_validator.as_ref()));

        Ok(Self {
            config,
            executor,
            subscription_manager,
            oidc_validator,
            rate_limiter,
            #[cfg(feature = "observers")]
            observer_runtime,
            #[cfg(feature = "observers")]
            db_pool,
            #[cfg(feature = "arrow")]
            flight_service,
        })
    }

    /// Like [`Self::new`], but serve the caller-supplied `flight_service`
    /// instead of building a default one.
    ///
    /// The supplied service is used as-is: the OIDC validator created from
    /// `config.auth` is *not* attached to it. Passing `None` disables Arrow
    /// Flight while the HTTP server still runs.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::ConfigError`] if the OIDC validator fails to
    /// initialize.
    #[cfg(feature = "arrow")]
    pub async fn with_flight_service(
        config: ServerConfig,
        schema: CompiledSchema,
        adapter: Arc<A>,
        // Only read when the `observers` feature is enabled.
        #[allow(unused_variables)] db_pool: Option<sqlx::PgPool>,
        flight_service: Option<FraiseQLFlightService>,
    ) -> Result<Self> {
        let executor = Arc::new(Executor::new(schema.clone(), adapter));
        let subscription_manager = Arc::new(SubscriptionManager::new(Arc::new(schema)));
        let oidc_validator = Self::init_oidc_validator(&config).await?;
        let rate_limiter = Self::init_rate_limiter(&config);
        #[cfg(feature = "observers")]
        let observer_runtime = Self::init_observer_runtime(&config, db_pool.as_ref()).await;

        Ok(Self {
            config,
            executor,
            subscription_manager,
            oidc_validator,
            rate_limiter,
            #[cfg(feature = "observers")]
            observer_runtime,
            #[cfg(feature = "observers")]
            db_pool,
            flight_service,
        })
    }

    /// Build the OIDC validator when `config.auth` is set; `Ok(None)` otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`ServerError::ConfigError`] if the validator fails to initialize.
    async fn init_oidc_validator(config: &ServerConfig) -> Result<Option<Arc<OidcValidator>>> {
        let Some(auth_config) = config.auth.as_ref() else {
            return Ok(None);
        };
        info!(
            issuer = %auth_config.issuer,
            "Initializing OIDC authentication"
        );
        let validator = OidcValidator::new(auth_config.clone())
            .await
            .map_err(|e| ServerError::ConfigError(format!("Failed to initialize OIDC: {e}")))?;
        Ok(Some(Arc::new(validator)))
    }

    /// Build the per-IP rate limiter from `config.rate_limiting`.
    ///
    /// Returns `None` when the section is absent, or (with a log line) when it
    /// is present but `enabled` is false.
    fn init_rate_limiter(config: &ServerConfig) -> Option<Arc<RateLimiter>> {
        let rate_config = config.rate_limiting.as_ref()?;
        if !rate_config.enabled {
            info!("Rate limiting disabled by configuration");
            return None;
        }
        info!(
            rps_per_ip = rate_config.rps_per_ip,
            rps_per_user = rate_config.rps_per_user,
            "Initializing rate limiting"
        );
        let limiter_config = crate::middleware::RateLimitConfig {
            enabled: true,
            rps_per_ip: rate_config.rps_per_ip,
            rps_per_user: rate_config.rps_per_user,
            burst_size: rate_config.burst_size,
            cleanup_interval_secs: rate_config.cleanup_interval_secs,
        };
        Some(Arc::new(RateLimiter::new(limiter_config)))
    }

    /// Build the default Arrow Flight service, attaching `oidc_validator`
    /// when present (otherwise Flight runs unauthenticated).
    #[cfg(feature = "arrow")]
    fn default_flight_service(
        oidc_validator: Option<&Arc<OidcValidator>>,
    ) -> FraiseQLFlightService {
        let mut service = FraiseQLFlightService::new();
        if let Some(validator) = oidc_validator {
            info!("Enabling OIDC authentication for Arrow Flight");
            service.set_oidc_validator(validator.clone());
        } else {
            info!("Arrow Flight initialized without authentication (dev mode)");
        }
        service
    }

    /// Create (but do not start) the observer runtime.
    ///
    /// Returns `None`, after logging why, when observers are absent or disabled
    /// in `config`, or when no database `pool` was supplied. The runtime is
    /// started by [`Self::serve`].
    #[cfg(feature = "observers")]
    async fn init_observer_runtime(
        config: &ServerConfig,
        pool: Option<&sqlx::PgPool>,
    ) -> Option<Arc<RwLock<ObserverRuntime>>> {
        let Some(observer_config) = config.observers.as_ref().filter(|cfg| cfg.enabled) else {
            info!("Observer runtime disabled");
            return None;
        };
        let Some(pool) = pool else {
            warn!("No database pool provided for observers");
            return None;
        };
        info!("Initializing observer runtime");
        let runtime_config = ObserverRuntimeConfig::new(pool.clone())
            .with_poll_interval(observer_config.poll_interval_ms)
            .with_batch_size(observer_config.batch_size)
            .with_channel_capacity(observer_config.channel_capacity);
        let runtime = ObserverRuntime::new(runtime_config);
        Some(Arc::new(RwLock::new(runtime)))
    }

    /// Assemble the axum [`Router`] for every endpoint enabled in `self.config`.
    ///
    /// - GraphQL GET/POST at `graphql_path` (behind OIDC whenever a validator
    ///   exists) and the health check at `health_path` are always mounted.
    /// - Playground, subscriptions, introspection + schema export, metrics and
    ///   admin routes follow their config flags; endpoints whose required
    ///   credential (OIDC validator or bearer token) is missing are skipped with
    ///   a warning.
    /// - The design audit API and `api::routes` are nested under `/api/v1`.
    ///
    /// Layers added later wrap those added earlier, so requests pass through
    /// rate limiting, CORS, tracing and metrics in that order. Observer routes
    /// are merged after the metrics layer and therefore bypass it.
    fn build_router(&self) -> Router {
        let state = AppState::new(self.executor.clone());
        // Shared with `metrics_middleware`, layered below.
        let metrics = state.metrics.clone();

        // GraphQL endpoint (GET and POST).
        let graphql_routes = Router::new().route(
            &self.config.graphql_path,
            get(graphql_get_handler::<A>).post(graphql_handler::<A>),
        );
        let graphql_routes = match self.oidc_validator {
            Some(ref validator) => {
                info!(
                    graphql_path = %self.config.graphql_path,
                    "GraphQL endpoint protected by OIDC authentication (GET and POST)"
                );
                graphql_routes.route_layer(middleware::from_fn_with_state(
                    OidcAuthState::new(validator.clone()),
                    oidc_auth_middleware,
                ))
            },
            None => graphql_routes,
        };

        let mut app = Router::new()
            .route(&self.config.health_path, get(health_handler::<A>))
            .with_state(state.clone())
            .merge(graphql_routes.with_state(state.clone()));

        if self.config.playground_enabled {
            let playground_state =
                PlaygroundState::new(self.config.graphql_path.clone(), self.config.playground_tool);
            info!(
                playground_path = %self.config.playground_path,
                playground_tool = ?self.config.playground_tool,
                "GraphQL playground enabled"
            );
            app = app.merge(
                Router::new()
                    .route(&self.config.playground_path, get(playground_handler))
                    .with_state(playground_state),
            );
        }

        if self.config.subscriptions_enabled {
            // NOTE(review): unlike the GraphQL endpoint, subscriptions are not
            // wrapped in OIDC auth even when a validator exists — confirm intended.
            let subscription_state = SubscriptionState::new(self.subscription_manager.clone());
            info!(
                subscription_path = %self.config.subscription_path,
                "GraphQL subscriptions enabled (graphql-ws protocol)"
            );
            app = app.merge(
                Router::new()
                    .route(&self.config.subscription_path, get(subscription_handler))
                    .with_state(subscription_state),
            );
        }

        if self.config.introspection_enabled {
            // Introspection and the SDL/JSON schema exports share one access policy.
            let schema_routes = Router::new()
                .route(&self.config.introspection_path, get(introspection_handler::<A>))
                .route("/api/v1/schema.graphql", get(api::schema::export_sdl_handler::<A>))
                .route("/api/v1/schema.json", get(api::schema::export_json_handler::<A>));
            let schema_routes = if !self.config.introspection_require_auth {
                info!(
                    introspection_path = %self.config.introspection_path,
                    "Introspection endpoint enabled (no auth required - USE ONLY IN DEVELOPMENT)"
                );
                Some(schema_routes)
            } else if let Some(ref validator) = self.oidc_validator {
                info!(
                    introspection_path = %self.config.introspection_path,
                    "Introspection endpoint enabled (OIDC auth required)"
                );
                Some(schema_routes.route_layer(middleware::from_fn_with_state(
                    OidcAuthState::new(validator.clone()),
                    oidc_auth_middleware,
                )))
            } else {
                // Fail closed: auth was requested but cannot be enforced.
                warn!(
                    "introspection_require_auth is true but no OIDC configured - introspection and schema export disabled"
                );
                None
            };
            if let Some(schema_routes) = schema_routes {
                app = app.merge(schema_routes.with_state(state.clone()));
            }
        }

        if self.config.metrics_enabled {
            if let Some(ref token) = self.config.metrics_token {
                info!(
                    metrics_path = %self.config.metrics_path,
                    metrics_json_path = %self.config.metrics_json_path,
                    "Metrics endpoints enabled (bearer token required)"
                );
                let auth_state = BearerAuthState::new(token.clone());
                let metrics_router = Router::new()
                    .route(&self.config.metrics_path, get(metrics_handler::<A>))
                    .route(&self.config.metrics_json_path, get(metrics_json_handler::<A>))
                    .route_layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
                    .with_state(state.clone());
                app = app.merge(metrics_router);
            } else {
                warn!(
                    "metrics_enabled is true but metrics_token is not set - metrics endpoints disabled"
                );
            }
        }

        if self.config.admin_api_enabled {
            if let Some(ref token) = self.config.admin_token {
                info!("Admin API endpoints enabled (bearer token required)");
                let auth_state = BearerAuthState::new(token.clone());
                let admin_router = Router::new()
                    .route(
                        "/api/v1/admin/reload-schema",
                        post(api::admin::reload_schema_handler::<A>),
                    )
                    .route("/api/v1/admin/cache/clear", post(api::admin::cache_clear_handler::<A>))
                    .route("/api/v1/admin/cache/stats", get(api::admin::cache_stats_handler::<A>))
                    .route("/api/v1/admin/config", get(api::admin::config_handler::<A>))
                    .route_layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
                    .with_state(state.clone());
                app = app.merge(admin_router);
            } else {
                warn!(
                    "admin_api_enabled is true but admin_token is not set - admin endpoints disabled"
                );
            }
        }

        // Design audit API: always mounted; OIDC-protected only when required
        // and a validator exists.
        let design_routes = Router::new()
            .route("/design/federation-audit", post(api::design::federation_audit_handler::<A>))
            .route("/design/cost-audit", post(api::design::cost_audit_handler::<A>))
            .route("/design/cache-audit", post(api::design::cache_audit_handler::<A>))
            .route("/design/auth-audit", post(api::design::auth_audit_handler::<A>))
            .route(
                "/design/compilation-audit",
                post(api::design::compilation_audit_handler::<A>),
            )
            .route("/design/audit", post(api::design::overall_design_audit_handler::<A>));
        let design_routes = match (self.config.design_api_require_auth, &self.oidc_validator) {
            (true, Some(validator)) => {
                info!("Design audit API endpoints enabled (OIDC auth required)");
                design_routes.route_layer(middleware::from_fn_with_state(
                    OidcAuthState::new(validator.clone()),
                    oidc_auth_middleware,
                ))
            },
            (true, None) => {
                // NOTE(review): this fails open, whereas introspection is
                // disabled in the same situation — confirm this is intended.
                warn!(
                    "design_api_require_auth is true but no OIDC configured - design endpoints unprotected"
                );
                design_routes
            },
            (false, _) => {
                info!("Design audit API endpoints enabled (no auth required)");
                design_routes
            },
        };
        app = app.nest("/api/v1", design_routes.with_state(state.clone()));

        app = app.nest("/api/v1", api::routes(state.clone()));
        app = app.layer(middleware::from_fn_with_state(metrics, metrics_middleware));

        #[cfg(feature = "observers")]
        {
            app = self.add_observer_routes(app);
        }

        if self.config.tracing_enabled {
            app = app.layer(trace_layer());
        }

        if self.config.cors_enabled {
            let origins = if self.config.cors_origins.is_empty() {
                tracing::warn!(
                    "CORS enabled but no origins configured. Using localhost:3000 as default. \
                     Set cors_origins in config for production."
                );
                vec!["http://localhost:3000".to_string()]
            } else {
                self.config.cors_origins.clone()
            };
            app = app.layer(cors_layer_restricted(origins));
        }

        if let Some(ref limiter) = self.rate_limiter {
            app = Self::apply_rate_limiting(app, limiter.clone());
        }

        app
    }

    /// Wrap `app` in middleware enforcing `limiter`'s per-IP limit.
    ///
    /// Over-limit requests get `429 Too Many Requests` with a JSON error body and
    /// `Retry-After: 60`; other responses are annotated with `X-RateLimit-Limit`
    /// and `X-RateLimit-Remaining`.
    ///
    /// The client IP comes from the `ConnectInfo<SocketAddr>` extractor, so the
    /// router must be served via `into_make_service_with_connect_info::<SocketAddr>()`
    /// (as [`Self::serve`] does); otherwise every request fails extraction.
    fn apply_rate_limiting(app: Router, limiter: Arc<RateLimiter>) -> Router {
        use std::net::SocketAddr;

        use axum::{extract::ConnectInfo, http::StatusCode, response::IntoResponse};

        info!("Enabling rate limiting middleware");
        app.layer(middleware::from_fn(
            move |ConnectInfo(addr): ConnectInfo<SocketAddr>, req, next: axum::middleware::Next| {
                let limiter = Arc::clone(&limiter);
                async move {
                    let ip = addr.ip().to_string();
                    if !limiter.check_ip_limit(&ip).await {
                        warn!(ip = %ip, "IP rate limit exceeded");
                        return (
                            StatusCode::TOO_MANY_REQUESTS,
                            [("Content-Type", "application/json"), ("Retry-After", "60")],
                            r#"{"errors":[{"message":"Rate limit exceeded. Please retry after 60 seconds."}]}"#,
                        )
                            .into_response();
                    }
                    let remaining = limiter.get_ip_remaining(&ip).await;
                    let mut response = next.run(req).await;
                    let headers = response.headers_mut();
                    if let Ok(limit_value) = format!("{}", limiter.config().rps_per_ip).parse() {
                        headers.insert("X-RateLimit-Limit", limit_value);
                    }
                    if let Ok(remaining_value) = format!("{}", remaining as u32).parse() {
                        headers.insert("X-RateLimit-Remaining", remaining_value);
                    }
                    response
                }
            },
        ))
    }

    /// Mount observer management routes under `/api/observers`, plus runtime
    /// health routes when an observer runtime exists.
    ///
    /// The management routes need `self.db_pool`; without a pool they are
    /// skipped with a warning rather than panicking.
    #[cfg(feature = "observers")]
    fn add_observer_routes(&self, app: Router) -> Router {
        use crate::observers::{
            ObserverRepository, ObserverState, RuntimeHealthState, observer_routes,
            observer_runtime_routes,
        };

        let Some(pool) = self.db_pool.clone() else {
            warn!("No database pool provided - observer management endpoints disabled");
            return app;
        };
        let observer_state = ObserverState {
            repository: ObserverRepository::new(pool),
        };
        let app = app.nest("/api/observers", observer_routes(observer_state));

        match self.observer_runtime {
            Some(ref runtime) => {
                info!(
                    path = "/api/observers",
                    "Observer management and runtime health endpoints enabled"
                );
                let runtime_state = RuntimeHealthState {
                    runtime: runtime.clone(),
                };
                app.merge(observer_runtime_routes(runtime_state))
            },
            None => app,
        }
    }

    /// Run the server until Ctrl+C or SIGTERM, then shut down gracefully.
    ///
    /// Starts the observer runtime (a start failure is logged, not fatal), binds
    /// `config.bind_addr` for plain HTTP, and, with the `arrow` feature and a
    /// Flight service present, spawns Arrow Flight on `0.0.0.0:50051`. On
    /// shutdown the observer runtime is stopped and the Flight task is aborted.
    ///
    /// # Errors
    ///
    /// Propagates TLS setup errors, returns [`ServerError::BindError`] if the
    /// listener cannot be bound, and [`ServerError::IoError`] if the HTTP server
    /// fails.
    pub async fn serve(self) -> Result<()> {
        use std::net::SocketAddr;

        let app = self.build_router();
        let tls_setup = TlsSetup::new(self.config.tls.clone(), self.config.database_tls.clone())?;
        info!(
            bind_addr = %self.config.bind_addr,
            graphql_path = %self.config.graphql_path,
            tls_enabled = tls_setup.is_tls_enabled(),
            "Starting FraiseQL server"
        );

        #[cfg(feature = "observers")]
        if let Some(ref runtime) = self.observer_runtime {
            info!("Starting observer runtime...");
            let mut guard = runtime.write().await;
            match guard.start().await {
                Ok(()) => info!("Observer runtime started"),
                Err(e) => {
                    error!("Failed to start observer runtime: {}", e);
                    warn!("Server will continue without observers");
                },
            }
            drop(guard);
        }

        let listener = TcpListener::bind(self.config.bind_addr)
            .await
            .map_err(|e| ServerError::BindError(e.to_string()))?;

        if tls_setup.is_tls_enabled() {
            // Built only so invalid TLS material fails startup; the config is
            // discarded and this listener serves plain HTTP.
            let _ = tls_setup.create_rustls_config()?;
            info!(
                cert_path = ?tls_setup.cert_path(),
                key_path = ?tls_setup.key_path(),
                mtls_required = tls_setup.is_mtls_required(),
                "Server TLS configuration loaded (note: use reverse proxy for server-side TLS termination)"
            );
        }
        info!(
            postgres_ssl_mode = tls_setup.postgres_ssl_mode(),
            redis_ssl = tls_setup.redis_ssl_enabled(),
            clickhouse_https = tls_setup.clickhouse_https_enabled(),
            elasticsearch_https = tls_setup.elasticsearch_https_enabled(),
            "Database connection TLS configuration applied"
        );
        info!("Server listening on http://{}", self.config.bind_addr);

        // Owned handle for the shutdown future so it does not capture `self`.
        #[cfg(feature = "observers")]
        let observer_runtime = self.observer_runtime.clone();

        #[cfg(feature = "arrow")]
        let flight_server = self.flight_service.map(|flight_service| {
            let flight_addr = SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 50051));
            info!("Arrow Flight server listening on grpc://{}", flight_addr);
            tokio::spawn(async move {
                if let Err(e) = tonic::transport::Server::builder()
                    .add_service(flight_service.into_server())
                    .serve(flight_addr)
                    .await
                {
                    tracing::error!("Arrow Flight server failed: {}", e);
                }
            })
        });

        let served = axum::serve(
            listener,
            // Connection info feeds the rate limiter's `ConnectInfo<SocketAddr>` extractor.
            app.into_make_service_with_connect_info::<SocketAddr>(),
        )
        .with_graceful_shutdown(async move {
            Self::shutdown_signal().await;
            #[cfg(feature = "observers")]
            if let Some(runtime) = observer_runtime {
                info!("Shutting down observer runtime");
                let mut guard = runtime.write().await;
                if let Err(e) = guard.stop().await {
                    error!("Error stopping runtime: {}", e);
                } else {
                    info!("Runtime stopped cleanly");
                }
            }
        })
        .await;

        // Stop Flight whether HTTP exited cleanly or with an error.
        #[cfg(feature = "arrow")]
        if let Some(flight_server) = flight_server {
            flight_server.abort();
        }

        served.map_err(|e| ServerError::IoError(std::io::Error::other(e)))?;
        Ok(())
    }

    /// Resolve when the process receives Ctrl+C or, on Unix, SIGTERM.
    ///
    /// # Panics
    ///
    /// Panics if a signal handler cannot be installed.
    async fn shutdown_signal() {
        use tokio::signal;

        let ctrl_c = async {
            signal::ctrl_c().await.expect("Failed to install Ctrl+C handler");
        };

        #[cfg(unix)]
        let terminate = async {
            signal::unix::signal(signal::unix::SignalKind::terminate())
                .expect("Failed to install SIGTERM handler")
                .recv()
                .await;
        };

        // No SIGTERM equivalent off Unix: only Ctrl+C can end the wait.
        #[cfg(not(unix))]
        let terminate = std::future::pending::<()>();

        tokio::select! {
            _ = ctrl_c => info!("Received Ctrl+C"),
            _ = terminate => info!("Received SIGTERM"),
        }
    }
}