1use std::sync::Arc;
2use std::time::Duration;
3
4use axum::{
5 Json, Router,
6 error_handling::HandleErrorLayer,
7 extract::DefaultBodyLimit,
8 http::StatusCode,
9 middleware,
10 response::IntoResponse,
11 routing::{get, post},
12};
13use serde::Serialize;
14use tower::BoxError;
15use tower::ServiceBuilder;
16use tower::limit::ConcurrencyLimitLayer;
17use tower::timeout::TimeoutLayer;
18use tower_http::cors::{Any, CorsLayer};
19
20use forge_core::cluster::NodeId;
21use forge_core::config::McpConfig;
22use forge_core::function::{JobDispatch, WorkflowDispatch};
23use opentelemetry::global;
24use opentelemetry::propagation::Extractor;
25use tracing::Instrument;
26use tracing_opentelemetry::OpenTelemetrySpanExt;
27
28use super::auth::{AuthConfig, AuthMiddleware, HmacTokenIssuer, auth_middleware};
29use super::mcp::{McpState, mcp_get_handler, mcp_post_handler};
30use super::multipart::rpc_multipart_handler;
31use super::response::{RpcError, RpcResponse};
32use super::rpc::{RpcHandler, rpc_batch_handler, rpc_function_handler, rpc_handler};
33use super::sse::{
34 SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler,
35 sse_unsubscribe_handler, sse_workflow_subscribe_handler,
36};
37use super::tracing::{REQUEST_ID_HEADER, SPAN_ID_HEADER, TRACE_ID_HEADER, TracingState};
38use crate::db::Database;
39use crate::function::FunctionRegistry;
40use crate::mcp::McpToolRegistry;
41use crate::realtime::{Reactor, ReactorConfig};
42
/// Largest request body, in bytes, accepted by the JSON routes (1 MiB).
const MAX_JSON_BODY_SIZE: usize = 1 << 20;
/// Largest request body, in bytes, accepted by the multipart upload route (20 MiB).
const MAX_MULTIPART_BODY_SIZE: usize = 20 << 20;
/// Limit handed to the concurrency layer that wraps the multipart upload route.
const MAX_MULTIPART_CONCURRENCY: usize = 32;
/// Secret given to the signals subsystem when `auth.jwt_secret` is not configured.
const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret";
48
/// Settings that control how the gateway HTTP server is assembled and served.
#[derive(Debug, Clone)]
pub struct GatewayConfig {
    /// TCP port the server binds to.
    pub port: u16,
    /// Limit passed to the concurrency layer that wraps the whole router.
    pub max_connections: usize,
    /// Maximum number of SSE sessions; forwarded to the SSE configuration.
    pub sse_max_sessions: usize,
    /// Per-request timeout, in seconds.
    pub request_timeout_secs: u64,
    /// Whether a CORS policy is configured; when `false` a bare `CorsLayer` is used.
    pub cors_enabled: bool,
    /// Allowed CORS origins. An entry equal to `"*"` allows any origin.
    pub cors_origins: Vec<String>,
    /// Authentication settings shared by the auth middleware and the token issuer.
    pub auth: AuthConfig,
    /// MCP settings (`enabled`, `path` and `oauth` are read by the server).
    pub mcp: McpConfig,
    /// Request paths, given with or without the `/_api` prefix, for which the tracing
    /// middleware skips span creation, request logging and metrics.
    pub quiet_routes: Vec<String>,
    /// Token lifetime settings; copied into the server at construction time.
    pub token_ttl: forge_core::AuthTokenTtl,
    /// Project name; handed to the OAuth state.
    pub project_name: String,
}
75
76impl Default for GatewayConfig {
77 fn default() -> Self {
78 Self {
79 port: 9081,
80 max_connections: 512,
81 sse_max_sessions: 10_000,
82 request_timeout_secs: 30,
83 cors_enabled: false,
84 cors_origins: Vec::new(),
85 auth: AuthConfig::default(),
86 mcp: McpConfig::default(),
87 quiet_routes: Vec::new(),
88 token_ttl: forge_core::AuthTokenTtl::default(),
89 project_name: "forge-app".to_string(),
90 }
91 }
92}
93
/// JSON body returned by the `/health` endpoint.
#[derive(Debug, Serialize)]
pub struct HealthResponse {
    /// Health label; the handler always reports `"healthy"`.
    pub status: String,
    /// Crate version taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
}
100
/// JSON body returned by the `/ready` endpoint.
#[derive(Debug, Serialize)]
pub struct ReadinessResponse {
    /// `true` only when the database, reactor and workflow checks all pass.
    pub ready: bool,
    /// Whether the `SELECT 1` probe against the database pool succeeded.
    pub database: bool,
    /// Whether the reactor statistics report its listener as running.
    pub reactor: bool,
    /// `false` when the blocked-workflow query counted at least one run; `true` when
    /// it counted none or when the check could not be performed.
    pub workflows: bool,
    /// Number of blocked workflow runs; only set when that number is positive and
    /// omitted from the JSON output otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blocked_workflow_runs: Option<i64>,
    /// Crate version taken from `CARGO_PKG_VERSION` at compile time.
    pub version: String,
}
112
/// Shared state handed to the readiness handler.
#[derive(Clone)]
pub struct ReadinessState {
    /// Pool used for the database probe and the blocked-workflow query.
    db_pool: sqlx::PgPool,
    /// Reactor whose statistics are inspected for the readiness report.
    reactor: Arc<Reactor>,
}
119
/// HTTP gateway: owns the configuration and shared services from which the axum
/// router is built.
pub struct GatewayServer {
    /// Gateway configuration supplied at construction time.
    config: GatewayConfig,
    /// Function registry handed to the RPC handler and the reactor.
    registry: FunctionRegistry,
    /// Database handle; its primary pool backs the reactor, readiness, OAuth and MCP state.
    db: Database,
    /// Reactor created in `new` and shared with the SSE and readiness state.
    reactor: Arc<Reactor>,
    /// Optional job dispatcher forwarded to the RPC and MCP handlers.
    job_dispatcher: Option<Arc<dyn JobDispatch>>,
    /// Optional workflow dispatcher forwarded to the RPC and MCP handlers.
    workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    /// Optional MCP tool registry; the type's default is used when unset.
    mcp_registry: Option<McpToolRegistry>,
    /// Token lifetimes copied from `config.token_ttl`.
    token_ttl: forge_core::AuthTokenTtl,
    /// Optional signals collector; the `/signal/*` routes exist only when it is set.
    signals_collector: Option<crate::signals::SignalsCollector>,
}
132
133impl GatewayServer {
134 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db: Database) -> Self {
136 let node_id = NodeId::new();
137 let reactor = Arc::new(Reactor::new(
138 node_id,
139 db.primary().clone(),
140 registry.clone(),
141 ReactorConfig::default(),
142 ));
143
144 let token_ttl = config.token_ttl.clone();
145 Self {
146 config,
147 registry,
148 db,
149 reactor,
150 job_dispatcher: None,
151 workflow_dispatcher: None,
152 mcp_registry: None,
153 token_ttl,
154 signals_collector: None,
155 }
156 }
157
158 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
160 self.job_dispatcher = Some(dispatcher);
161 self
162 }
163
164 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
166 self.workflow_dispatcher = Some(dispatcher);
167 self
168 }
169
170 pub fn with_mcp_registry(mut self, registry: McpToolRegistry) -> Self {
172 self.mcp_registry = Some(registry);
173 self
174 }
175
176 pub fn with_signals_collector(mut self, collector: crate::signals::SignalsCollector) -> Self {
179 self.signals_collector = Some(collector);
180 self
181 }
182
183 pub fn reactor(&self) -> Arc<Reactor> {
185 self.reactor.clone()
186 }
187
188 pub fn oauth_router(&self) -> Option<(Router, Arc<super::oauth::OAuthState>)> {
190 if !self.config.mcp.oauth {
191 return None;
192 }
193
194 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
195 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>)?;
196
197 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
198
199 let jwt_secret = self.config.auth.jwt_secret.clone().unwrap_or_default();
200
201 let oauth_state = Arc::new(super::oauth::OAuthState::new(
202 self.db.primary().clone(),
203 auth_middleware_state,
204 token_issuer,
205 self.token_ttl.access_token_secs,
206 self.token_ttl.refresh_token_days,
207 self.config.auth.is_hmac(),
208 self.config.project_name.clone(),
209 jwt_secret,
210 ));
211
212 let router = Router::new()
213 .route(
214 "/oauth/authorize",
215 get(super::oauth::oauth_authorize_get).post(super::oauth::oauth_authorize_post),
216 )
217 .route("/oauth/token", post(super::oauth::oauth_token))
218 .route("/oauth/register", post(super::oauth::oauth_register))
219 .with_state(oauth_state.clone());
220
221 Some((router, oauth_state))
222 }
223
224 pub fn router(&self) -> Router {
226 let token_issuer = HmacTokenIssuer::from_config(&self.config.auth)
227 .map(|issuer| Arc::new(issuer) as Arc<dyn forge_core::TokenIssuer>);
228
229 let mut rpc = RpcHandler::with_dispatch_and_issuer(
230 self.registry.clone(),
231 self.db.clone(),
232 self.job_dispatcher.clone(),
233 self.workflow_dispatcher.clone(),
234 token_issuer,
235 );
236 rpc.set_token_ttl(self.token_ttl.clone());
237 if let Some(collector) = &self.signals_collector {
238 let secret = self
239 .config
240 .auth
241 .jwt_secret
242 .clone()
243 .unwrap_or_else(|| DEFAULT_SIGNAL_SECRET.to_string());
244 rpc.set_signals_collector(collector.clone(), secret);
245 }
246 let rpc_handler_state = Arc::new(rpc);
247
248 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
249
250 let cors = if self.config.cors_enabled {
256 if self.config.cors_origins.iter().any(|o| o == "*") {
257 CorsLayer::new()
259 .allow_origin(Any)
260 .allow_methods(Any)
261 .allow_headers(Any)
262 } else {
263 use axum::http::Method;
264 let origins: Vec<_> = self
265 .config
266 .cors_origins
267 .iter()
268 .filter_map(|o| o.parse().ok())
269 .collect();
270 CorsLayer::new()
271 .allow_origin(origins)
272 .allow_methods([
273 Method::GET,
274 Method::POST,
275 Method::PUT,
276 Method::DELETE,
277 Method::PATCH,
278 Method::OPTIONS,
279 ])
280 .allow_headers([
281 axum::http::header::CONTENT_TYPE,
282 axum::http::header::AUTHORIZATION,
283 axum::http::header::ACCEPT,
284 axum::http::HeaderName::from_static("x-webhook-signature"),
285 axum::http::HeaderName::from_static("x-idempotency-key"),
286 axum::http::HeaderName::from_static("x-correlation-id"),
287 axum::http::HeaderName::from_static("x-session-id"),
288 axum::http::HeaderName::from_static("x-forge-platform"),
289 ])
290 .allow_credentials(true)
291 }
292 } else {
293 CorsLayer::new()
294 };
295
296 let sse_state = Arc::new(SseState::with_config(
298 self.reactor.clone(),
299 auth_middleware_state.clone(),
300 super::sse::SseConfig {
301 max_sessions: self.config.sse_max_sessions,
302 ..Default::default()
303 },
304 ));
305
306 let readiness_state = Arc::new(ReadinessState {
308 db_pool: self.db.primary().clone(),
309 reactor: self.reactor.clone(),
310 });
311
312 let mut main_router = Router::new()
314 .route("/health", get(health_handler))
316 .route("/ready", get(readiness_handler).with_state(readiness_state))
318 .route("/rpc", post(rpc_handler))
320 .route("/rpc/batch", post(rpc_batch_handler))
322 .route("/rpc/{function}", post(rpc_function_handler))
324 .layer(DefaultBodyLimit::max(MAX_JSON_BODY_SIZE))
326 .with_state(rpc_handler_state.clone());
328
329 let multipart_router = Router::new()
331 .route("/rpc/{function}/upload", post(rpc_multipart_handler))
332 .layer(DefaultBodyLimit::max(MAX_MULTIPART_BODY_SIZE))
333 .layer(ConcurrencyLimitLayer::new(MAX_MULTIPART_CONCURRENCY))
335 .with_state(rpc_handler_state);
336
337 let sse_router = Router::new()
339 .route("/events", get(sse_handler))
340 .route("/subscribe", post(sse_subscribe_handler))
341 .route("/unsubscribe", post(sse_unsubscribe_handler))
342 .route("/subscribe-job", post(sse_job_subscribe_handler))
343 .route("/subscribe-workflow", post(sse_workflow_subscribe_handler))
344 .with_state(sse_state);
345
346 let mut mcp_router = Router::new();
347 if self.config.mcp.enabled {
348 let path = self.config.mcp.path.clone();
349 let mcp_state = Arc::new(McpState::new(
350 self.config.mcp.clone(),
351 self.mcp_registry.clone().unwrap_or_default(),
352 self.db.primary().clone(),
353 self.job_dispatcher.clone(),
354 self.workflow_dispatcher.clone(),
355 ));
356 mcp_router = mcp_router.route(
357 &path,
358 post(mcp_post_handler)
359 .get(mcp_get_handler)
360 .with_state(mcp_state),
361 );
362 }
363
364 let mut signals_router = Router::new();
366 if let Some(collector) = &self.signals_collector {
367 let signals_state = Arc::new(crate::signals::endpoints::SignalsState {
368 collector: collector.clone(),
369 pool: self.db.analytics_pool().clone(),
370 server_secret: self
371 .config
372 .auth
373 .jwt_secret
374 .clone()
375 .unwrap_or_else(|| DEFAULT_SIGNAL_SECRET.to_string()),
376 });
377 signals_router = Router::new()
378 .route(
379 "/signal/event",
380 post(crate::signals::endpoints::event_handler),
381 )
382 .route(
383 "/signal/view",
384 post(crate::signals::endpoints::view_handler),
385 )
386 .route(
387 "/signal/user",
388 post(crate::signals::endpoints::user_handler),
389 )
390 .route(
391 "/signal/report",
392 post(crate::signals::endpoints::report_handler),
393 )
394 .with_state(signals_state);
395 }
396
397 main_router = main_router
398 .merge(multipart_router)
399 .merge(sse_router)
400 .merge(mcp_router)
401 .merge(signals_router);
402
403 let service_builder = ServiceBuilder::new()
405 .layer(HandleErrorLayer::new(handle_middleware_error))
406 .layer(ConcurrencyLimitLayer::new(self.config.max_connections))
407 .layer(TimeoutLayer::new(Duration::from_secs(
408 self.config.request_timeout_secs,
409 )))
410 .layer(cors.clone())
411 .layer(middleware::from_fn_with_state(
412 auth_middleware_state,
413 auth_middleware,
414 ))
415 .layer(middleware::from_fn_with_state(
416 Arc::new(self.config.quiet_routes.clone()),
417 tracing_middleware,
418 ));
419
420 main_router.layer(service_builder)
422 }
423
424 pub fn addr(&self) -> std::net::SocketAddr {
426 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
427 }
428
429 pub async fn run(self) -> Result<(), std::io::Error> {
431 let addr = self.addr();
432 let router = self.router();
433
434 self.reactor
436 .start()
437 .await
438 .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?;
439 tracing::info!("Reactor started for real-time updates");
440
441 tracing::info!("Gateway server listening on {}", addr);
442
443 let listener = tokio::net::TcpListener::bind(addr).await?;
444 axum::serve(listener, router.into_make_service()).await
445 }
446}
447
448async fn health_handler() -> Json<HealthResponse> {
450 Json(HealthResponse {
451 status: "healthy".to_string(),
452 version: env!("CARGO_PKG_VERSION").to_string(),
453 })
454}
455
456async fn readiness_handler(
458 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
459) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
460 let db_ok = sqlx::query_scalar!("SELECT 1 as \"v!\"")
462 .fetch_one(&state.db_pool)
463 .await
464 .is_ok();
465
466 let reactor_stats = state.reactor.stats().await;
468 let reactor_ok = reactor_stats.listener_running;
469
470 let (workflows_ok, blocked_count) = if db_ok {
472 match sqlx::query_scalar!(
473 r#"SELECT COUNT(*) as "count!" FROM forge_workflow_runs WHERE status LIKE 'blocked_%'"#,
474 )
475 .fetch_one(&state.db_pool)
476 .await
477 {
478 Ok(count) => (count == 0, if count > 0 { Some(count) } else { None }),
479 Err(_) => (true, None), }
481 } else {
482 (true, None)
483 };
484
485 let ready = db_ok && reactor_ok && workflows_ok;
486 let status_code = if ready {
487 axum::http::StatusCode::OK
488 } else {
489 axum::http::StatusCode::SERVICE_UNAVAILABLE
490 };
491
492 (
493 status_code,
494 Json(ReadinessResponse {
495 ready,
496 database: db_ok,
497 reactor: reactor_ok,
498 workflows: workflows_ok,
499 blocked_workflow_runs: blocked_count,
500 version: env!("CARGO_PKG_VERSION").to_string(),
501 }),
502 )
503}
504
505async fn handle_middleware_error(err: BoxError) -> axum::response::Response {
506 let (status, code, message) = if err.is::<tower::timeout::error::Elapsed>() {
507 (StatusCode::REQUEST_TIMEOUT, "TIMEOUT", "Request timed out")
508 } else {
509 (
510 StatusCode::SERVICE_UNAVAILABLE,
511 "SERVICE_UNAVAILABLE",
512 "Server overloaded",
513 )
514 };
515 (
516 status,
517 Json(RpcResponse::error(RpcError::new(code, message))),
518 )
519 .into_response()
520}
521
522fn set_tracing_headers(response: &mut axum::response::Response, trace_id: &str, request_id: &str) {
523 if let Ok(val) = trace_id.parse() {
524 response.headers_mut().insert(TRACE_ID_HEADER, val);
525 }
526 if let Ok(val) = request_id.parse() {
527 response.headers_mut().insert(REQUEST_ID_HEADER, val);
528 }
529}
530
531struct HeaderExtractor<'a>(&'a axum::http::HeaderMap);
533
534impl<'a> Extractor for HeaderExtractor<'a> {
535 fn get(&self, key: &str) -> Option<&str> {
536 self.0.get(key).and_then(|v| v.to_str().ok())
537 }
538
539 fn keys(&self) -> Vec<&str> {
540 self.0.keys().map(|k| k.as_str()).collect()
541 }
542}
543
544async fn tracing_middleware(
550 axum::extract::State(quiet_routes): axum::extract::State<Arc<Vec<String>>>,
551 req: axum::extract::Request,
552 next: axum::middleware::Next,
553) -> axum::response::Response {
554 let headers = req.headers();
555
556 let parent_cx =
558 global::get_text_map_propagator(|propagator| propagator.extract(&HeaderExtractor(headers)));
559
560 let trace_id = headers
561 .get(TRACE_ID_HEADER)
562 .and_then(|v| v.to_str().ok())
563 .map(String::from)
564 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
565
566 let parent_span_id = headers
567 .get(SPAN_ID_HEADER)
568 .and_then(|v| v.to_str().ok())
569 .map(String::from);
570
571 let method = req.method().to_string();
572 let path = req.uri().path().to_string();
573
574 let mut tracing_state = TracingState::with_trace_id(trace_id.clone());
575 if let Some(span_id) = parent_span_id {
576 tracing_state = tracing_state.with_parent_span(span_id);
577 }
578
579 let mut req = req;
580 req.extensions_mut().insert(tracing_state.clone());
581
582 if req
583 .extensions()
584 .get::<forge_core::function::AuthContext>()
585 .is_none()
586 {
587 req.extensions_mut()
588 .insert(forge_core::function::AuthContext::unauthenticated());
589 }
590
591 let full_path = format!("/_api{}", path);
594 let is_quiet = quiet_routes.iter().any(|r| *r == full_path || *r == path);
595
596 if is_quiet {
597 let mut response = next.run(req).await;
598 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
599 return response;
600 }
601
602 let span = tracing::info_span!(
603 "http.request",
604 http.method = %method,
605 http.route = %path,
606 http.status_code = tracing::field::Empty,
607 trace_id = %trace_id,
608 request_id = %tracing_state.request_id,
609 );
610
611 span.set_parent(parent_cx);
614
615 let mut response = next.run(req).instrument(span.clone()).await;
616
617 let status = response.status().as_u16();
618 let elapsed = tracing_state.elapsed();
619
620 span.record("http.status_code", status);
621 let duration_ms = elapsed.as_millis() as u64;
622 match status {
623 500..=599 => tracing::error!(parent: &span, duration_ms, "Request failed"),
624 400..=499 => tracing::warn!(parent: &span, duration_ms, "Request rejected"),
625 200..=299 => tracing::info!(parent: &span, duration_ms, "Request completed"),
626 _ => tracing::trace!(parent: &span, duration_ms, "Request completed"),
627 }
628 crate::observability::record_http_request(&method, &path, status, elapsed.as_secs_f64());
629
630 set_tracing_headers(&mut response, &trace_id, &tracing_state.request_id);
631 response
632}
633
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The default configuration keeps its port, connection limit and CORS setting.
    #[test]
    fn test_gateway_config_default() {
        let GatewayConfig {
            port,
            max_connections,
            cors_enabled,
            ..
        } = GatewayConfig::default();

        assert_eq!(port, 9081);
        assert_eq!(max_connections, 512);
        assert!(!cors_enabled);
    }

    /// A health response serializes to JSON that contains its status text.
    #[test]
    fn test_health_response_serialization() {
        let json = serde_json::to_string(&HealthResponse {
            status: String::from("healthy"),
            version: String::from("0.1.0"),
        })
        .unwrap();

        assert!(json.contains("healthy"));
    }
}