1use std::sync::Arc;
2
3use axum::{
4 Json, Router, middleware,
5 routing::{any, get, post},
6};
7use serde::Serialize;
8use tower::ServiceBuilder;
9use tower_http::cors::{Any, CorsLayer};
10
11use forge_core::cluster::NodeId;
12use forge_core::function::{JobDispatch, WorkflowDispatch};
13
14use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
15use super::metrics::{MetricsState, metrics_middleware};
16use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
17use super::tracing::TracingState;
18use super::websocket::{WsState, ws_handler};
19use crate::function::FunctionRegistry;
20use crate::observability::ObservabilityState;
21use crate::realtime::{Reactor, ReactorConfig};
22
23#[derive(Debug, Clone)]
25pub struct GatewayConfig {
26 pub port: u16,
28 pub max_connections: usize,
30 pub request_timeout_secs: u64,
32 pub cors_enabled: bool,
34 pub cors_origins: Vec<String>,
36 pub auth: AuthConfig,
38}
39
40impl Default for GatewayConfig {
41 fn default() -> Self {
42 Self {
43 port: 8080,
44 max_connections: 10000,
45 request_timeout_secs: 30,
46 cors_enabled: true,
47 cors_origins: vec!["*".to_string()],
48 auth: AuthConfig::default(),
49 }
50 }
51}
52
53#[derive(Debug, Serialize)]
55pub struct HealthResponse {
56 pub status: String,
57 pub version: String,
58}
59
60#[derive(Debug, Serialize)]
62pub struct ReadinessResponse {
63 pub ready: bool,
64 pub database: bool,
65 pub version: String,
66}
67
68#[derive(Clone)]
70pub struct ReadinessState {
71 db_pool: sqlx::PgPool,
72}
73
74pub struct GatewayServer {
76 config: GatewayConfig,
77 registry: FunctionRegistry,
78 db_pool: sqlx::PgPool,
79 reactor: Arc<Reactor>,
80 observability: Option<ObservabilityState>,
81 job_dispatcher: Option<Arc<dyn JobDispatch>>,
82 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
83}
84
85impl GatewayServer {
86 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
88 let node_id = NodeId::new();
89 let reactor = Arc::new(Reactor::new(
90 node_id,
91 db_pool.clone(),
92 registry.clone(),
93 ReactorConfig::default(),
94 ));
95
96 Self {
97 config,
98 registry,
99 db_pool,
100 reactor,
101 observability: None,
102 job_dispatcher: None,
103 workflow_dispatcher: None,
104 }
105 }
106
107 pub fn with_observability(
109 config: GatewayConfig,
110 registry: FunctionRegistry,
111 db_pool: sqlx::PgPool,
112 observability: ObservabilityState,
113 ) -> Self {
114 let node_id = NodeId::new();
115 let reactor = Arc::new(Reactor::new(
116 node_id,
117 db_pool.clone(),
118 registry.clone(),
119 ReactorConfig::default(),
120 ));
121
122 Self {
123 config,
124 registry,
125 db_pool,
126 reactor,
127 observability: Some(observability),
128 job_dispatcher: None,
129 workflow_dispatcher: None,
130 }
131 }
132
133 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
135 self.job_dispatcher = Some(dispatcher);
136 self
137 }
138
139 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
141 self.workflow_dispatcher = Some(dispatcher);
142 self
143 }
144
145 pub fn reactor(&self) -> Arc<Reactor> {
147 self.reactor.clone()
148 }
149
150 pub fn router(&self) -> Router {
152 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
153 self.registry.clone(),
154 self.db_pool.clone(),
155 self.job_dispatcher.clone(),
156 self.workflow_dispatcher.clone(),
157 ));
158
159 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
160
161 let cors = if self.config.cors_enabled {
163 if self.config.cors_origins.contains(&"*".to_string()) {
164 CorsLayer::new()
165 .allow_origin(Any)
166 .allow_methods(Any)
167 .allow_headers(Any)
168 } else {
169 let origins: Vec<_> = self
170 .config
171 .cors_origins
172 .iter()
173 .filter_map(|o| o.parse().ok())
174 .collect();
175 CorsLayer::new()
176 .allow_origin(origins)
177 .allow_methods(Any)
178 .allow_headers(Any)
179 }
180 } else {
181 CorsLayer::new()
182 };
183
184 let node_id = self.reactor.node_id();
186 let ws_state = Arc::new(WsState::with_auth(
187 self.reactor.clone(),
188 self.db_pool.clone(),
189 node_id,
190 auth_middleware_state.clone(),
191 ));
192
193 let readiness_state = Arc::new(ReadinessState {
195 db_pool: self.db_pool.clone(),
196 });
197
198 let mut main_router = Router::new()
200 .route("/health", get(health_handler))
202 .route("/ready", get(readiness_handler).with_state(readiness_state))
204 .route("/rpc", post(rpc_handler))
206 .route("/rpc/{function}", post(rpc_function_handler))
208 .with_state(rpc_handler_state);
210
211 let service_builder = ServiceBuilder::new()
213 .layer(cors.clone())
214 .layer(middleware::from_fn_with_state(
215 auth_middleware_state,
216 auth_middleware,
217 ))
218 .layer(middleware::from_fn(tracing_middleware));
219
220 if let Some(ref observability) = self.observability {
222 let metrics_state = Arc::new(MetricsState::new(observability.clone()));
223 main_router = main_router.layer(middleware::from_fn_with_state(
224 metrics_state,
225 metrics_middleware,
226 ));
227 }
228
229 main_router = main_router.layer(service_builder);
231
232 let ws_router = Router::new()
234 .route("/ws", any(ws_handler).with_state(ws_state))
235 .layer(cors);
236
237 main_router.merge(ws_router)
239 }
240
241 pub fn addr(&self) -> std::net::SocketAddr {
243 std::net::SocketAddr::from(([0, 0, 0, 0], self.config.port))
244 }
245
246 pub async fn run(self) -> Result<(), std::io::Error> {
248 let addr = self.addr();
249 let router = self.router();
250
251 if let Err(e) = self.reactor.start().await {
253 tracing::error!("Failed to start reactor: {}", e);
254 } else {
255 tracing::info!("Reactor started for real-time updates");
256 }
257
258 tracing::info!("Gateway server listening on {}", addr);
259
260 let listener = tokio::net::TcpListener::bind(addr).await?;
261 axum::serve(listener, router.into_make_service()).await
262 }
263}
264
265async fn health_handler() -> Json<HealthResponse> {
267 Json(HealthResponse {
268 status: "healthy".to_string(),
269 version: env!("CARGO_PKG_VERSION").to_string(),
270 })
271}
272
273async fn readiness_handler(
275 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
276) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
277 let db_ok = sqlx::query("SELECT 1")
279 .fetch_one(&state.db_pool)
280 .await
281 .is_ok();
282
283 let ready = db_ok;
284 let status_code = if ready {
285 axum::http::StatusCode::OK
286 } else {
287 axum::http::StatusCode::SERVICE_UNAVAILABLE
288 };
289
290 (
291 status_code,
292 Json(ReadinessResponse {
293 ready,
294 database: db_ok,
295 version: env!("CARGO_PKG_VERSION").to_string(),
296 }),
297 )
298}
299
300async fn tracing_middleware(
302 req: axum::extract::Request,
303 next: axum::middleware::Next,
304) -> axum::response::Response {
305 use axum::http::header::HeaderName;
306
307 let trace_id = req
309 .headers()
310 .get(HeaderName::from_static("x-trace-id"))
311 .and_then(|v| v.to_str().ok())
312 .map(String::from)
313 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
314
315 let tracing_state = TracingState::with_trace_id(trace_id.clone());
316
317 let mut req = req;
318 req.extensions_mut().insert(tracing_state.clone());
319
320 if req
322 .extensions()
323 .get::<forge_core::function::AuthContext>()
324 .is_none()
325 {
326 req.extensions_mut()
327 .insert(forge_core::function::AuthContext::unauthenticated());
328 }
329
330 let mut response = next.run(req).await;
331
332 if let Ok(val) = trace_id.parse() {
334 response.headers_mut().insert("x-trace-id", val);
335 }
336 if let Ok(val) = tracing_state.request_id.parse() {
337 response.headers_mut().insert("x-request-id", val);
338 }
339
340 response
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every scalar default of `GatewayConfig` (auth defaults are owned by `AuthConfig`).
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 10000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
    }

    /// Checks the exact JSON fields emitted for the liveness payload.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["status"], "healthy");
        assert_eq!(json["version"], "0.1.0");
    }

    /// Checks the exact JSON fields emitted for the readiness payload.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["ready"], false);
        assert_eq!(json["database"], false);
        assert_eq!(json["version"], "0.1.0");
    }
}