1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Extension, Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::{MultipartConfig, rpc_multipart_handler};
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Request body cap (1 MiB) applied to the main router (health, readiness and JSON RPC routes).
const DEFAULT_MAX_JSON_BODY_SIZE: usize = 1024 * 1024;
/// Default value of `GatewayConfig::max_body_size_bytes` (20 MiB), used for multipart uploads.
const DEFAULT_MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024;
/// Concurrency limit placed on the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
/// Fallback secret handed to the signals components when no `jwt_secret` is configured.
/// It is a fixed value compiled into the binary, so anything derived from it is predictable.
const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret";
48
/// Settings that control how the gateway router is assembled and served.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port the server binds to (on all IPv4 interfaces, see `GatewayServer::addr`).
    pub port: u16,
    /// Limit passed to the gateway-wide `ConcurrencyLimitLayer`.
    pub max_connections: usize,
    /// Forwarded to the SSE state as `SseConfig::max_sessions`.
    pub sse_max_sessions: usize,
    /// Per-request timeout, in seconds, enforced by the `TimeoutLayer`.
    pub request_timeout_secs: u64,
    /// When `false`, a default (empty) `CorsLayer` is installed and `cors_origins` is ignored.
    pub cors_enabled: bool,
    /// Allowed origins; an entry equal to `"*"` switches to allow-any origin/method/header.
    pub cors_origins: Vec<String>,
    /// Authentication settings used by the auth middleware and the token issuer.
    pub auth: AuthConfig,
    /// MCP settings (`enabled`, `path` and `oauth` are read when building routers).
    pub mcp: McpConfig,
    /// Paths (with or without the `/_api` prefix) for which the tracing middleware
    /// skips the request span, log line and metrics.
    pub quiet_routes: Vec<String>,
    /// Token lifetimes; copied into `GatewayServer` at construction.
    pub token_ttl: forge_core::AuthTokenTtl,
    /// Project name passed to the OAuth state.
    pub project_name: String,
    /// Multipart body size handed to the upload handler and used as the lower
    /// bound of the upload route's body limit.
    pub max_body_size_bytes: usize,
}
77
78impl Default for GatewayConfig {
79 fn default() -> Self {
80 Self {
81 port: 9081,
82 max_connections: 512,
83 sse_max_sessions: 10_000,
84 request_timeout_secs: 30,
85 cors_enabled: false,
86 cors_origins: Vec::new(),
87 auth: AuthConfig::default(),
88 mcp: McpConfig::default(),
89 quiet_routes: Vec::new(),
90 token_ttl: forge_core::AuthTokenTtl::default(),
91 project_name: "forge-app".to_string(),
92 max_body_size_bytes: DEFAULT_MAX_MULTIPART_BODY_SIZE,
93 }
94 }
95}
96
/// JSON body returned by the `/health` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Health label; the handler in this file always reports `"healthy"`.
    pub status: String,
    /// Crate version (`CARGO_PKG_VERSION`) when produced by the handler in this file.
    pub version: String,
}
103
/// JSON body returned by the `/ready` endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// Overall verdict: `database && reactor && workflows`.
    pub ready: bool,
    /// Whether the `SELECT 1` probe against the database succeeded.
    pub database: bool,
    /// Whether the reactor reports its listener as running.
    pub reactor: bool,
    /// `false` only when at least one workflow run has a `blocked_*` status.
    pub workflows: bool,
    /// Number of blocked workflow runs; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blocked_workflow_runs: Option<i64>,
    /// Crate version (`CARGO_PKG_VERSION`) when produced by the handler in this file.
    pub version: String,
}
115
/// State consumed by the `/ready` handler.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool used for the database and blocked-workflow probes.
    db_pool: sqlx::PgPool,
    /// Reactor whose stats supply the `listener_running` flag.
    reactor: Arc<Reactor>,
}
122
/// HTTP gateway: holds the configuration and shared handles from which the
/// axum router is assembled and served.
pub struct GatewayServer {
    /// Gateway settings (port, limits, CORS, auth, MCP, ...).
    config: GatewayConfig,
    /// Function registry handed to the RPC handler and the reactor.
    registry: FunctionRegistry,
    /// Database handle; its `primary()` and `analytics_pool()` pools are used by the routes.
    db: Database,
    /// Reactor created in `new`; shared with the SSE state and the readiness probe.
    reactor: Arc<Reactor>,
    /// Optional job dispatcher forwarded to the RPC and MCP handlers.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to the RPC and MCP handlers.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Optional MCP tool registry; the default registry is used when absent.
    mcp_registry: Option<McpToolRegistry>,
    /// Token lifetimes copied from `config.token_ttl` at construction.
    token_ttl: forge_core::AuthTokenTtl,
    /// Optional signals collector; when set, the `/signal/*` routes are mounted.
    signals_collector: Option<crate::signals::SignalsCollector>,
    /// Forwarded to the signals endpoint state as `anonymize_ip`.
    signals_anonymize_ip: bool,
}
136
137impl GatewayServer {
138 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
140 let node_id = NodeId::new();
141 let reactor = Arc::new(Reactor::new(
142 node_id,
143 db.primary().clone(),
144 registry.clone(),
145 ReactorConfig::default(),
146 ));
147
148 let token_ttl = config.token_ttl.clone();
149 Self {
150 config,
151 registry,
152 db,
153 reactor,
154 job_dispatcher: None,
155 workflow_dispatcher: None,
156 mcp_registry: None,
157 token_ttl,
158 signals_collector: None,
159 signals_anonymize_ip: false,
160 }
161 }
162
163 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
165 self.job_dispatcher = Some(dispatcher);
166 self
167 }
168
169 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
171 self.workflow_dispatcher = Some(dispatcher);
172 self
173 }
174
175 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
177 self.mcp_registry = Some(registry);
178 self
179 }
180
181 pub fn with_signals_collector(mut self, collector: crate::signals::SignalsCollector) -> Self {
184 self.signals_collector = Some(collector);
185 self
186 }
187
188 pub fn with_signals_anonymize_ip(mut self, anonymize: bool) -> Self {
191 self.signals_anonymize_ip = anonymize;
192 self
193 }
194
195 pub fn reactor(&self) -> Arc<Reactor> {
197 self.reactor.clone()
198 }
199
200 pub fn oauth_router(&self) -> Option<(Router, Arc<super::oauth::OAuthState>)> {
202 if !self.config.mcp.oauth {
203 return None;
204 }
205
206 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
207 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>)?;
208
209 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
210
211 let jwt_secret = self.config.auth.jwt_secret.clone().unwrap_or_default();
212
213 let oauth_state = Arc::new(super::oauth::OAuthState::new(
214 self.db.primary().clone(),
215 auth_middleware_state,
216 token_issuer,
217 self.token_ttl.access_token_secs,
218 self.token_ttl.refresh_token_days,
219 self.config.auth.is_hmac(),
220 self.config.project_name.clone(),
221 jwt_secret,
222 ));
223
224 let router = Router::new()
225 .route(
226 "/oauth/authorize",
227 get(super::oauth::oauth_authorize_get).post(super::oauth::oauth_authorize_post),
228 )
229 .route("/oauth/token", post(super::oauth::oauth_token))
230 .route("/oauth/register", post(super::oauth::oauth_register))
231 .with_state(oauth_state.clone());
232
233 Some((router, oauth_state))
234 }
235
236 pub fn router(&self) -> Router {
238 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
239 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
240
241 let mut rpc = RpcHandler::with_dispatch_and_issuer(
242 self.registry.clone(),
243 self.db.clone(),
244 self.job_dispatcher.clone(),
245 self.workflow_dispatcher.clone(),
246 token_issuer,
247 );
248 rpc.set_token_ttl(self.token_ttl.clone());
249 if let Some(collector) = &self.signals_collector {
250 let secret = self.config.auth.jwt_secret.clone().unwrap_or_else(|| {
251 tracing::warn!(
252 "No jwt_secret configured; using default signal secret for visitor ID hashing. \
253 Visitor IDs will be predictable. Set [auth] jwt_secret in forge.toml."
254 );
255 DEFAULT_SIGNAL_SECRET.to_string()
256 });
257 rpc.set_signals_collector(collector.clone(), secret);
258 }
259 let rpc_handler_state = Arc::new(rpc);
260
261 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
262
263 let cors = if self.config.cors_enabled {
269 if self.config.cors_origins.iter().any(|o| o == "*") {
270 CorsLayer::new()
272 .allow_origin(Any)
273 .allow_methods(Any)
274 .allow_headers(Any)
275 } else {
276 use axum::http::Method;
277 let origins: Vec<_> = self
278 .config
279 .cors_origins
280 .iter()
281 .filter_map(|o| o.parse().ok())
282 .collect();
283 CorsLayer::new()
284 .allow_origin(origins)
285 .allow_methods([
286 Method::GET,
287 Method::POST,
288 Method::PUT,
289 Method::DELETE,
290 Method::PATCH,
291 Method::OPTIONS,
292 ])
293 .allow_headers([
294 axum::http::header::CONTENT_TYPE,
295 axum::http::header::AUTHORIZATION,
296 axum::http::header::ACCEPT,
297 axum::http::HeaderName::from_static("x-webhook-signature"),
298 axum::http::HeaderName::from_static("x-idempotency-key"),
299 axum::http::HeaderName::from_static("x-correlation-id"),
300 axum::http::HeaderName::from_static("x-session-id"),
301 axum::http::HeaderName::from_static("x-forge-platform"),
302 ])
303 .allow_credentials(true)
304 }
305 } else {
306 CorsLayer::new()
307 };
308
309 let sse_state = Arc::new(SseState::with_config(
311 self.reactor.clone(),
312 auth_middleware_state.clone(),
313 super::sse::SseConfig {
314 max_sessions: self.config.sse_max_sessions,
315 ..Default::default()
316 },
317 ));
318
319 let readiness_state = Arc::new(ReadinessState {
321 db_pool: self.db.primary().clone(),
322 reactor: self.reactor.clone(),
323 });
324
325 let mut main_router = Router::new()
327 .route("/health", get(health_handler))
329 .route("/ready", get(readiness_handler).with_state(readiness_state))
331 .route("/rpc", post(rpc_handler))
333 .route("/rpc/batch", post(rpc_batch_handler))
335 .route("/rpc/{function}", post(rpc_function_handler))
337 .layer(DefaultBodyLimit::max(DEFAULT_MAX_JSON_BODY_SIZE))
339 .with_state(rpc_handler_state.clone());
341
342 let max_per_mutation = self
347 .registry
348 .functions()
349 .filter_map(|(_, entry)| entry.info().max_upload_size_bytes)
350 .max()
351 .unwrap_or(0);
352 let layer_limit = self.config.max_body_size_bytes.max(max_per_mutation);
353 let mp_config = MultipartConfig {
354 max_body_size_bytes: self.config.max_body_size_bytes,
355 };
356 let multipart_router = Router::new()
357 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
358 .layer(DefaultBodyLimit::max(layer_limit))
359 .layer(Extension(mp_config))
360 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
362 .with_state(rpc_handler_state);
363
364 let sse_router = Router::new()
366 .route("/events", get(sse_handler))
367 .route("/subscribe", post(sse_subscribe_handler))
368 .route("/unsubscribe", post(sse_unsubscribe_handler))
369 .route("/subscribe-job", post(sse_job_subscribe_handler))
370 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
371 .with_state(sse_state);
372
373 let mut mcp_router = Router::new();
374 if self.config.mcp.enabled {
375 let path = self.config.mcp.path.clone();
376 let mcp_state = Arc::new(McpState::new(
377 self.config.mcp.clone(),
378 self.mcp_registry.clone().unwrap_or_default(),
379 self.db.primary().clone(),
380 self.job_dispatcher.clone(),
381 self.workflow_dispatcher.clone(),
382 ));
383 mcp_router = mcp_router.route(
384 &path,
385 post(mcp_post_handler)
386 .get(mcp_get_handler)
387 .with_state(mcp_state),
388 );
389 }
390
391 let mut signals_router = Router::new();
393 if let Some(collector) = &self.signals_collector {
394 let signals_state = Arc::new(crate::signals::endpoints::SignalsState {
395 collector: collector.clone(),
396 pool: self.db.analytics_pool().clone(),
397 server_secret: self
398 .config
399 .auth
400 .jwt_secret
401 .clone()
402 .unwrap_or_else(|| {
403 tracing::warn!(
404 "No jwt_secret configured; using default signal secret for visitor ID hashing. \
405 Visitor IDs will be predictable. Set [auth] jwt_secret in forge.toml."
406 );
407 DEFAULT_SIGNAL_SECRET.to_string()
408 }),
409 anonymize_ip: self.signals_anonymize_ip,
410 });
411 signals_router = Router::new()
412 .route(
413 "/signal/event",
414 post(crate::signals::endpoints::event_handler),
415 )
416 .route(
417 "/signal/view",
418 post(crate::signals::endpoints::view_handler),
419 )
420 .route(
421 "/signal/user",
422 post(crate::signals::endpoints::user_handler),
423 )
424 .route(
425 "/signal/report",
426 post(crate::signals::endpoints::report_handler),
427 )
428 .with_state(signals_state);
429 }
430
431 main_router = main_router
432 .merge(multipart_router)
433 .merge(sse_router)
434 .merge(mcp_router)
435 .merge(signals_router);
436
437 let service_builder = ServiceBuilder::new()
439 .layer(HandleErrorLayer::new(handle_middleware_error))
440 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
441 .layer(TimeoutLayer::new(Duration::from_secs(
442 self.config.request_timeout_secs,
443 )))
444 .layer(cors.clone())
445 .layer(middleware::from_fn_with_state(
446 auth_middleware_state,
447 auth_middleware,
448 ))
449 .layer(middleware::from_fn_with_state(
450 Arc::new(self.config.quiet_routes.clone()),
451 tracing_middleware,
452 ));
453
454 main_router.layer(service_builder)
456 }
457
458 pub fn addr(&self) -> std::net::SocketAddr {
460 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
461 }
462
463 pub async fn run(self) -> Result<(), std::io::Error> {
465 let addr = self.addr();
466 let router = self.router();
467
468 self.reactor
470 .start()
471 .await
472 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
473 tracing::info!("Reactor started for real-time updates");
474
475 tracing::info!("Gateway server listening on {}", addr);
476
477 let listener = tokio::net::TcpListener::bind(addr).await?;
478 axum::serve(listener, router.into_make_service()).await
479 }
480}
481
482async fn health_handler() -> Json<HealthResponse> {
484 Json(HealthResponse {
485 status: "healthy".to_string(),
486 version: env!("CARGO_PKG_VERSION").to_string(),
487 })
488}
489
490async fn readiness_handler(
492 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
493) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
494 let db_ok = sqlx::query_scalar!("SELECT 1 as \"v!\"")
496 .fetch_one(&state.db_pool)
497 .await
498 .is_ok();
499
500 let reactor_stats = state.reactor.stats().await;
502 let reactor_ok = reactor_stats.listener_running;
503
504 let (workflows_ok, blocked_count) = if db_ok {
506 match sqlx::query_scalar!(
507 r#"SELECT COUNT(*) as "count!" FROM forge_workflow_runs WHERE status LIKE 'blocked_%'"#,
508 )
509 .fetch_one(&state.db_pool)
510 .await
511 {
512 Ok(count) => (count == 0, if count > 0 { Some(count) } else { None }),
513 Err(_) => (true, None), }
515 } else {
516 (true, None)
517 };
518
519 let ready = db_ok && reactor_ok && workflows_ok;
520 let status_code = if ready {
521 axum::http::StatusCode::OK
522 } else {
523 axum::http::StatusCode::SERVICE_UNAVAILABLE
524 };
525
526 (
527 status_code,
528 Json(ReadinessResponse {
529 ready,
530 database: db_ok,
531 reactor: reactor_ok,
532 workflows: workflows_ok,
533 blocked_workflow_runs: blocked_count,
534 version: env!("CARGO_PKG_VERSION").to_string(),
535 }),
536 )
537}
538
539async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
540 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
541 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
542 } else {
543 (
544 StatusCode::SERVICE_UNAVAILABLE,
545 "SERVICE_UNAVAILABLE",
546 "Server overloaded",
547 )
548 };
549 (
550 status,
551 Json(RpcResponse::error(RpcError::new(code, message))),
552 )
553 .into_response()
554}
555
556fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
557 if let Ok(val) = trace_id.parse() {
558 response.headers_mut().insert(TRACE_ID_HEADER, val);
559 }
560 if let Ok(val) = request_id.parse() {
561 response.headers_mut().insert(REQUEST_ID_HEADER, val);
562 }
563}
564
565struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
567
568impl<'a> Extractor for HeaderExtractor<'a> {
569 fn get(&self, key: &str) -> Option<&str> {
570 self.0.get(key).and_then(|v| v.to_str().ok())
571 }
572
573 fn keys(&self) -> Vec<&str> {
574 self.0.keys().map(|k| k.as_str()).collect()
575 }
576}
577
578async fn tracing_middleware(
584 axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
585 req: axum::extract::Request,
586 next: axum::middleware::Next,
587) -> axum::response::Response {
588 let headers = req.headers();
589
590 let parent_cx =
592 global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
593
594 let trace_id = headers
595 .get(TRACE_ID_HEADER)
596 .and_then(|v| v.to_str().ok())
597 .map(String::from)
598 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
599
600 let parent_span_id = headers
601 .get(SPAN_ID_HEADER)
602 .and_then(|v| v.to_str().ok())
603 .map(String::from);
604
605 let method = req.method().to_string();
606 let path = req.uri().path().to_string();
607
608 let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
609 if let Some(span_id) = parent_span_id {
610 tracing_state = tracing_state.with_parent_span(span_id);
611 }
612
613 let mut req = req;
614 req.extensions_mut().insert(tracing_state.clone());
615
616 if req
617 .extensions()
618 .get::<forge_core::function::AuthContext>()
619 .is_none()
620 {
621 req.extensions_mut()
622 .insert(forge_core::function::AuthContext::unauthenticated());
623 }
624
625 let full_path = format!("/_api{}", path);
628 let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
629
630 if is_quiet {
631 let mut response = next.run(req).await;
632 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
633 return response;
634 }
635
636 let span = tracing::info_span!(
637 "http.request",
638 http.method = %method,
639 http.route = %path,
640 http.status_code = tracing::field::Empty,
641 trace_id = %trace_id,
642 request_id = %tracing_state.request_id,
643 );
644
645 span.set_parent(parent_cx);
648
649 let mut response = next.run(req).instrument(span.clone()).await;
650
651 let status = response.status().as_u16();
652 let elapsed = tracing_state.elapsed();
653
654 span.record("http.status_code", status);
655 let duration_ms = elapsed.as_millis() as u64;
656 match status {
657 500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
658 400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
659 200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
660 _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
661 }
662 crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
663
664 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
665 response
666}
667
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default configuration uses port 9081, allows 512 concurrent
    /// requests and leaves CORS disabled.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 9081);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A serialized health response contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let json = serde_json::to_string(&HealthResponse {
            status: String::from("healthy"),
            version: String::from("0.1.0"),
        })
        .unwrap();

        assert!(json.contains("healthy"));
    }
}