1use std::sync::Arc;
4
5use axum::{
6 Router, middleware,
7 routing::{get, post},
8};
9#[cfg(feature = "arrow")]
10use fraiseql_arrow::FraiseQLFlightService;
11use fraiseql_core::{
12 db::traits::DatabaseAdapter,
13 runtime::{Executor, SubscriptionManager},
14 schema::CompiledSchema,
15 security::OidcValidator,
16};
17use tokio::net::TcpListener;
18#[cfg(feature = "observers")]
19use tracing::error;
20use tracing::{info, warn};
21#[cfg(feature = "observers")]
22use {
23 crate::observers::{ObserverRuntime, ObserverRuntimeConfig},
24 tokio::sync::RwLock,
25};
26
27use crate::{
28 Result, ServerError,
29 middleware::{
30 BearerAuthState, OidcAuthState, RateLimiter, bearer_auth_middleware, cors_layer_restricted,
31 metrics_middleware, oidc_auth_middleware, trace_layer,
32 },
33 routes::{
34 PlaygroundState, SubscriptionState, api, graphql::AppState, graphql_get_handler,
35 graphql_handler, health_handler, introspection_handler, metrics_handler,
36 metrics_json_handler, playground_handler, subscription_handler,
37 },
38 server_config::ServerConfig,
39 tls::TlsSetup,
40};
41
42pub struct Server<A: DatabaseAdapter> {
44 config: ServerConfig,
45 executor: Arc<Executor<A>>,
46 subscription_manager: Arc<SubscriptionManager>,
47 oidc_validator: Option<Arc<OidcValidator>>,
48 rate_limiter: Option<Arc<RateLimiter>>,
49
50 #[cfg(feature = "observers")]
51 observer_runtime: Option<Arc<RwLock<ObserverRuntime>>>,
52
53 #[cfg(feature = "observers")]
54 db_pool: Option<sqlx::PgPool>,
55
56 #[cfg(feature = "arrow")]
57 flight_service: Option<FraiseQLFlightService>,
58}
59
60impl<A: DatabaseAdapter + Clone + Send + Sync + 'static> Server<A> {
61 pub async fn new(
86 config: ServerConfig,
87 schema: CompiledSchema,
88 adapter: Arc<A>,
89 #[allow(unused_variables)] db_pool: Option<sqlx::PgPool>,
90 ) -> Result<Self> {
91 let executor = Arc::new(Executor::new(schema.clone(), adapter));
92 let subscription_manager = Arc::new(SubscriptionManager::new(Arc::new(schema)));
93
94 let oidc_validator = if let Some(ref auth_config) = config.auth {
96 info!(
97 issuer = %auth_config.issuer,
98 "Initializing OIDC authentication"
99 );
100 let validator = OidcValidator::new(auth_config.clone())
101 .await
102 .map_err(|e| ServerError::ConfigError(format!("Failed to initialize OIDC: {e}")))?;
103 Some(Arc::new(validator))
104 } else {
105 None
106 };
107
108 let rate_limiter = if let Some(ref rate_config) = config.rate_limiting {
110 if rate_config.enabled {
111 info!(
112 rps_per_ip = rate_config.rps_per_ip,
113 rps_per_user = rate_config.rps_per_user,
114 "Initializing rate limiting"
115 );
116 let limiter_config = crate::middleware::RateLimitConfig {
117 enabled: true,
118 rps_per_ip: rate_config.rps_per_ip,
119 rps_per_user: rate_config.rps_per_user,
120 burst_size: rate_config.burst_size,
121 cleanup_interval_secs: rate_config.cleanup_interval_secs,
122 };
123 Some(Arc::new(RateLimiter::new(limiter_config)))
124 } else {
125 info!("Rate limiting disabled by configuration");
126 None
127 }
128 } else {
129 None
130 };
131
132 #[cfg(feature = "observers")]
134 let observer_runtime = Self::init_observer_runtime(&config, db_pool.as_ref()).await;
135
136 #[cfg(feature = "arrow")]
138 let flight_service = {
139 let mut service = FraiseQLFlightService::new();
140 if let Some(ref validator) = oidc_validator {
141 info!("Enabling OIDC authentication for Arrow Flight");
142 service.set_oidc_validator(validator.clone());
143 } else {
144 info!("Arrow Flight initialized without authentication (dev mode)");
145 }
146 Some(service)
147 };
148
149 Ok(Self {
150 config,
151 executor,
152 subscription_manager,
153 oidc_validator,
154 rate_limiter,
155 #[cfg(feature = "observers")]
156 observer_runtime,
157 #[cfg(feature = "observers")]
158 db_pool,
159 #[cfg(feature = "arrow")]
160 flight_service,
161 })
162 }
163
164 #[cfg(feature = "arrow")]
180 pub async fn with_flight_service(
181 config: ServerConfig,
182 schema: CompiledSchema,
183 adapter: Arc<A>,
184 #[allow(unused_variables)] db_pool: Option<sqlx::PgPool>,
185 flight_service: Option<FraiseQLFlightService>,
186 ) -> Result<Self> {
187 let executor = Arc::new(Executor::new(schema.clone(), adapter));
188 let subscription_manager = Arc::new(SubscriptionManager::new(Arc::new(schema)));
189
190 let oidc_validator = if let Some(ref auth_config) = config.auth {
192 info!(
193 issuer = %auth_config.issuer,
194 "Initializing OIDC authentication"
195 );
196 let validator = OidcValidator::new(auth_config.clone())
197 .await
198 .map_err(|e| ServerError::ConfigError(format!("Failed to initialize OIDC: {e}")))?;
199 Some(Arc::new(validator))
200 } else {
201 None
202 };
203
204 let rate_limiter = if let Some(ref rate_config) = config.rate_limiting {
206 if rate_config.enabled {
207 info!(
208 rps_per_ip = rate_config.rps_per_ip,
209 rps_per_user = rate_config.rps_per_user,
210 "Initializing rate limiting"
211 );
212 let limiter_config = crate::middleware::RateLimitConfig {
213 enabled: true,
214 rps_per_ip: rate_config.rps_per_ip,
215 rps_per_user: rate_config.rps_per_user,
216 burst_size: rate_config.burst_size,
217 cleanup_interval_secs: rate_config.cleanup_interval_secs,
218 };
219 Some(Arc::new(RateLimiter::new(limiter_config)))
220 } else {
221 info!("Rate limiting disabled by configuration");
222 None
223 }
224 } else {
225 None
226 };
227
228 #[cfg(feature = "observers")]
230 let observer_runtime = Self::init_observer_runtime(&config, db_pool.as_ref()).await;
231
232 Ok(Self {
233 config,
234 executor,
235 subscription_manager,
236 oidc_validator,
237 rate_limiter,
238 #[cfg(feature = "observers")]
239 observer_runtime,
240 #[cfg(feature = "observers")]
241 db_pool,
242 flight_service,
243 })
244 }
245
246 #[cfg(feature = "observers")]
248 async fn init_observer_runtime(
249 config: &ServerConfig,
250 pool: Option<&sqlx::PgPool>,
251 ) -> Option<Arc<RwLock<ObserverRuntime>>> {
252 let observer_config = match &config.observers {
254 Some(cfg) if cfg.enabled => cfg,
255 _ => {
256 info!("Observer runtime disabled");
257 return None;
258 },
259 };
260
261 let pool = match pool {
262 Some(p) => p,
263 None => {
264 warn!("No database pool provided for observers");
265 return None;
266 },
267 };
268
269 info!("Initializing observer runtime");
270
271 let runtime_config = ObserverRuntimeConfig::new(pool.clone())
272 .with_poll_interval(observer_config.poll_interval_ms)
273 .with_batch_size(observer_config.batch_size)
274 .with_channel_capacity(observer_config.channel_capacity);
275
276 let runtime = ObserverRuntime::new(runtime_config);
277 Some(Arc::new(RwLock::new(runtime)))
278 }
279
280 fn build_router(&self) -> Router {
282 let state = AppState::new(self.executor.clone());
283 let metrics = state.metrics.clone();
284
285 let graphql_router = if let Some(ref validator) = self.oidc_validator {
288 info!(
289 graphql_path = %self.config.graphql_path,
290 "GraphQL endpoint protected by OIDC authentication (GET and POST)"
291 );
292 let auth_state = OidcAuthState::new(validator.clone());
293 Router::new()
294 .route(
295 &self.config.graphql_path,
296 get(graphql_get_handler::<A>).post(graphql_handler::<A>),
297 )
298 .route_layer(middleware::from_fn_with_state(auth_state, oidc_auth_middleware))
299 .with_state(state.clone())
300 } else {
301 Router::new()
302 .route(
303 &self.config.graphql_path,
304 get(graphql_get_handler::<A>).post(graphql_handler::<A>),
305 )
306 .with_state(state.clone())
307 };
308
309 let mut app = Router::new()
311 .route(&self.config.health_path, get(health_handler::<A>))
312 .with_state(state.clone())
313 .merge(graphql_router);
314
315 if self.config.playground_enabled {
317 let playground_state =
318 PlaygroundState::new(self.config.graphql_path.clone(), self.config.playground_tool);
319 info!(
320 playground_path = %self.config.playground_path,
321 playground_tool = ?self.config.playground_tool,
322 "GraphQL playground enabled"
323 );
324 let playground_router = Router::new()
325 .route(&self.config.playground_path, get(playground_handler))
326 .with_state(playground_state);
327 app = app.merge(playground_router);
328 }
329
330 if self.config.subscriptions_enabled {
332 let subscription_state = SubscriptionState::new(self.subscription_manager.clone());
333 info!(
334 subscription_path = %self.config.subscription_path,
335 "GraphQL subscriptions enabled (graphql-ws protocol)"
336 );
337 let subscription_router = Router::new()
338 .route(&self.config.subscription_path, get(subscription_handler))
339 .with_state(subscription_state);
340 app = app.merge(subscription_router);
341 }
342
343 if self.config.introspection_enabled {
345 if self.config.introspection_require_auth {
346 if let Some(ref validator) = self.oidc_validator {
347 info!(
348 introspection_path = %self.config.introspection_path,
349 "Introspection endpoint enabled (OIDC auth required)"
350 );
351 let auth_state = OidcAuthState::new(validator.clone());
352 let introspection_router = Router::new()
353 .route(&self.config.introspection_path, get(introspection_handler::<A>))
354 .route_layer(middleware::from_fn_with_state(
355 auth_state.clone(),
356 oidc_auth_middleware,
357 ))
358 .with_state(state.clone());
359 app = app.merge(introspection_router);
360
361 let schema_router = Router::new()
363 .route("/api/v1/schema.graphql", get(api::schema::export_sdl_handler::<A>))
364 .route("/api/v1/schema.json", get(api::schema::export_json_handler::<A>))
365 .route_layer(middleware::from_fn_with_state(
366 auth_state,
367 oidc_auth_middleware,
368 ))
369 .with_state(state.clone());
370 app = app.merge(schema_router);
371 } else {
372 warn!(
373 "introspection_require_auth is true but no OIDC configured - introspection and schema export disabled"
374 );
375 }
376 } else {
377 info!(
378 introspection_path = %self.config.introspection_path,
379 "Introspection endpoint enabled (no auth required - USE ONLY IN DEVELOPMENT)"
380 );
381 let introspection_router = Router::new()
382 .route(&self.config.introspection_path, get(introspection_handler::<A>))
383 .with_state(state.clone());
384 app = app.merge(introspection_router);
385
386 let schema_router = Router::new()
389 .route("/api/v1/schema.graphql", get(api::schema::export_sdl_handler::<A>))
390 .route("/api/v1/schema.json", get(api::schema::export_json_handler::<A>))
391 .with_state(state.clone());
392 app = app.merge(schema_router);
393 }
394 }
395
396 if self.config.metrics_enabled {
398 if let Some(ref token) = self.config.metrics_token {
399 info!(
400 metrics_path = %self.config.metrics_path,
401 metrics_json_path = %self.config.metrics_json_path,
402 "Metrics endpoints enabled (bearer token required)"
403 );
404
405 let auth_state = BearerAuthState::new(token.clone());
406
407 let metrics_router = Router::new()
410 .route(&self.config.metrics_path, get(metrics_handler::<A>))
411 .route(&self.config.metrics_json_path, get(metrics_json_handler::<A>))
412 .route_layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
413 .with_state(state.clone());
414
415 app = app.merge(metrics_router);
416 } else {
417 warn!(
418 "metrics_enabled is true but metrics_token is not set - metrics endpoints disabled"
419 );
420 }
421 }
422
423 if self.config.admin_api_enabled {
425 if let Some(ref token) = self.config.admin_token {
426 info!("Admin API endpoints enabled (bearer token required)");
427
428 let auth_state = BearerAuthState::new(token.clone());
429
430 let admin_router = Router::new()
432 .route(
433 "/api/v1/admin/reload-schema",
434 post(api::admin::reload_schema_handler::<A>),
435 )
436 .route("/api/v1/admin/cache/clear", post(api::admin::cache_clear_handler::<A>))
437 .route("/api/v1/admin/cache/stats", get(api::admin::cache_stats_handler::<A>))
438 .route("/api/v1/admin/config", get(api::admin::config_handler::<A>))
439 .route_layer(middleware::from_fn_with_state(auth_state, bearer_auth_middleware))
440 .with_state(state.clone());
441
442 app = app.merge(admin_router);
443 } else {
444 warn!(
445 "admin_api_enabled is true but admin_token is not set - admin endpoints disabled"
446 );
447 }
448 }
449
450 if self.config.design_api_require_auth {
452 if let Some(ref validator) = self.oidc_validator {
453 info!("Design audit API endpoints enabled (OIDC auth required)");
454 let auth_state = OidcAuthState::new(validator.clone());
455 let design_router = Router::new()
456 .route(
457 "/design/federation-audit",
458 post(api::design::federation_audit_handler::<A>),
459 )
460 .route("/design/cost-audit", post(api::design::cost_audit_handler::<A>))
461 .route("/design/cache-audit", post(api::design::cache_audit_handler::<A>))
462 .route("/design/auth-audit", post(api::design::auth_audit_handler::<A>))
463 .route(
464 "/design/compilation-audit",
465 post(api::design::compilation_audit_handler::<A>),
466 )
467 .route("/design/audit", post(api::design::overall_design_audit_handler::<A>))
468 .route_layer(middleware::from_fn_with_state(auth_state, oidc_auth_middleware))
469 .with_state(state.clone());
470 app = app.nest("/api/v1", design_router);
471 } else {
472 warn!(
473 "design_api_require_auth is true but no OIDC configured - design endpoints unprotected"
474 );
475 let design_router = Router::new()
477 .route(
478 "/design/federation-audit",
479 post(api::design::federation_audit_handler::<A>),
480 )
481 .route("/design/cost-audit", post(api::design::cost_audit_handler::<A>))
482 .route("/design/cache-audit", post(api::design::cache_audit_handler::<A>))
483 .route("/design/auth-audit", post(api::design::auth_audit_handler::<A>))
484 .route(
485 "/design/compilation-audit",
486 post(api::design::compilation_audit_handler::<A>),
487 )
488 .route("/design/audit", post(api::design::overall_design_audit_handler::<A>))
489 .with_state(state.clone());
490 app = app.nest("/api/v1", design_router);
491 }
492 } else {
493 info!("Design audit API endpoints enabled (no auth required)");
494 let design_router = Router::new()
495 .route("/design/federation-audit", post(api::design::federation_audit_handler::<A>))
496 .route("/design/cost-audit", post(api::design::cost_audit_handler::<A>))
497 .route("/design/cache-audit", post(api::design::cache_audit_handler::<A>))
498 .route("/design/auth-audit", post(api::design::auth_audit_handler::<A>))
499 .route(
500 "/design/compilation-audit",
501 post(api::design::compilation_audit_handler::<A>),
502 )
503 .route("/design/audit", post(api::design::overall_design_audit_handler::<A>))
504 .with_state(state.clone());
505 app = app.nest("/api/v1", design_router);
506 }
507
508 let api_router = api::routes(state.clone());
510 app = app.nest("/api/v1", api_router);
511
512 app = app.layer(middleware::from_fn_with_state(metrics, metrics_middleware));
515
516 #[cfg(feature = "observers")]
518 {
519 app = self.add_observer_routes(app);
520 }
521
522 if self.config.tracing_enabled {
524 app = app.layer(trace_layer());
525 }
526
527 if self.config.cors_enabled {
528 let origins = if self.config.cors_origins.is_empty() {
530 tracing::warn!(
532 "CORS enabled but no origins configured. Using localhost:3000 as default. \
533 Set cors_origins in config for production."
534 );
535 vec!["http://localhost:3000".to_string()]
536 } else {
537 self.config.cors_origins.clone()
538 };
539 app = app.layer(cors_layer_restricted(origins));
540 }
541
542 if let Some(ref limiter) = self.rate_limiter {
544 use std::net::SocketAddr;
545
546 use axum::extract::ConnectInfo;
547
548 info!("Enabling rate limiting middleware");
549 let limiter_clone = limiter.clone();
550 app = app.layer(middleware::from_fn(move |ConnectInfo(addr): ConnectInfo<SocketAddr>, req, next: axum::middleware::Next| {
551 let limiter = limiter_clone.clone();
552 async move {
553 let ip = addr.ip().to_string();
554
555 if !limiter.check_ip_limit(&ip).await {
557 warn!(ip = %ip, "IP rate limit exceeded");
558 use axum::http::StatusCode;
559 use axum::response::IntoResponse;
560 return (
561 StatusCode::TOO_MANY_REQUESTS,
562 [("Content-Type", "application/json"), ("Retry-After", "60")],
563 r#"{"errors":[{"message":"Rate limit exceeded. Please retry after 60 seconds."}]}"#,
564 ).into_response();
565 }
566
567 let remaining = limiter.get_ip_remaining(&ip).await;
569 let mut response = next.run(req).await;
570
571 let headers = response.headers_mut();
573 if let Ok(limit_value) = format!("{}", limiter.config().rps_per_ip).parse() {
574 headers.insert("X-RateLimit-Limit", limit_value);
575 }
576 if let Ok(remaining_value) = format!("{}", remaining as u32).parse() {
577 headers.insert("X-RateLimit-Remaining", remaining_value);
578 }
579
580 response
581 }
582 }));
583 }
584
585 app
586 }
587
588 #[cfg(feature = "observers")]
590 fn add_observer_routes(&self, app: Router) -> Router {
591 use crate::observers::{
592 ObserverRepository, ObserverState, RuntimeHealthState, observer_routes,
593 observer_runtime_routes,
594 };
595
596 let observer_state = ObserverState {
598 repository: ObserverRepository::new(
599 self.db_pool.clone().expect("Pool required for observers"),
600 ),
601 };
602
603 let app = app.nest("/api/observers", observer_routes(observer_state));
604
605 if let Some(ref runtime) = self.observer_runtime {
607 info!(
608 path = "/api/observers",
609 "Observer management and runtime health endpoints enabled"
610 );
611
612 let runtime_state = RuntimeHealthState {
613 runtime: runtime.clone(),
614 };
615
616 app.merge(observer_runtime_routes(runtime_state))
617 } else {
618 app
619 }
620 }
621
622 pub async fn serve(self) -> Result<()> {
628 let app = self.build_router();
629
630 let tls_setup = TlsSetup::new(self.config.tls.clone(), self.config.database_tls.clone())?;
632
633 info!(
634 bind_addr = %self.config.bind_addr,
635 graphql_path = %self.config.graphql_path,
636 tls_enabled = tls_setup.is_tls_enabled(),
637 "Starting FraiseQL server"
638 );
639
640 #[cfg(feature = "observers")]
642 if let Some(ref runtime) = self.observer_runtime {
643 info!("Starting observer runtime...");
644 let mut guard = runtime.write().await;
645
646 match guard.start().await {
647 Ok(()) => info!("Observer runtime started"),
648 Err(e) => {
649 error!("Failed to start observer runtime: {}", e);
650 warn!("Server will continue without observers");
651 },
652 }
653 drop(guard);
654 }
655
656 let listener = TcpListener::bind(self.config.bind_addr)
657 .await
658 .map_err(|e| ServerError::BindError(e.to_string()))?;
659
660 if tls_setup.is_tls_enabled() {
662 let _ = tls_setup.create_rustls_config()?;
664 info!(
665 cert_path = ?tls_setup.cert_path(),
666 key_path = ?tls_setup.key_path(),
667 mtls_required = tls_setup.is_mtls_required(),
668 "Server TLS configuration loaded (note: use reverse proxy for server-side TLS termination)"
669 );
670 }
671
672 info!(
674 postgres_ssl_mode = tls_setup.postgres_ssl_mode(),
675 redis_ssl = tls_setup.redis_ssl_enabled(),
676 clickhouse_https = tls_setup.clickhouse_https_enabled(),
677 elasticsearch_https = tls_setup.elasticsearch_https_enabled(),
678 "Database connection TLS configuration applied"
679 );
680
681 info!("Server listening on http://{}", self.config.bind_addr);
682
683 #[cfg(feature = "arrow")]
685 if let Some(flight_service) = self.flight_service {
686 let flight_addr = "0.0.0.0:50051".parse().expect("Valid Flight address");
688 info!("Arrow Flight server listening on grpc://{}", flight_addr);
689
690 let flight_server = tokio::spawn(async move {
692 tonic::transport::Server::builder()
693 .add_service(flight_service.into_server())
694 .serve(flight_addr)
695 .await
696 });
697
698 axum::serve(listener, app)
700 .with_graceful_shutdown(async move {
701 Self::shutdown_signal().await;
702
703 #[cfg(feature = "observers")]
705 if let Some(ref runtime) = self.observer_runtime {
706 info!("Shutting down observer runtime");
707 let mut guard = runtime.write().await;
708 if let Err(e) = guard.stop().await {
709 error!("Error stopping runtime: {}", e);
710 } else {
711 info!("Runtime stopped cleanly");
712 }
713 }
714 })
715 .await
716 .map_err(|e| ServerError::IoError(std::io::Error::other(e)))?;
717
718 flight_server.abort();
720 }
721
722 #[cfg(not(feature = "arrow"))]
724 {
725 axum::serve(listener, app)
726 .with_graceful_shutdown(async move {
727 Self::shutdown_signal().await;
728
729 #[cfg(feature = "observers")]
731 if let Some(ref runtime) = self.observer_runtime {
732 info!("Shutting down observer runtime");
733 let mut guard = runtime.write().await;
734 if let Err(e) = guard.stop().await {
735 error!("Error stopping runtime: {}", e);
736 } else {
737 info!("Runtime stopped cleanly");
738 }
739 }
740 })
741 .await
742 .map_err(|e| ServerError::IoError(std::io::Error::other(e)))?;
743 }
744
745 Ok(())
746 }
747
748 async fn shutdown_signal() {
750 use tokio::signal;
751
752 let ctrl_c = async {
753 signal::ctrl_c().await.expect("Failed to install Ctrl+C handler");
754 };
755
756 #[cfg(unix)]
757 let terminate = async {
758 signal::unix::signal(signal::unix::SignalKind::terminate())
759 .expect("Failed to install SIGTERM handler")
760 .recv()
761 .await;
762 };
763
764 #[cfg(not(unix))]
765 let terminate = std::future::pending::<()>();
766
767 tokio::select! {
768 _ = ctrl_c => info!("Received Ctrl+C"),
769 _ = terminate => info!("Received SIGTERM"),
770 }
771 }
772}