1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5 Json, Router, middleware,
6 routing::{any, get, post},
7};
8use serde::Serialize;
9use tower::ServiceBuilder;
10use tower_http::cors::{Any, CorsLayer};
11
12use forge_core::cluster::NodeId;
13use forge_core::function::{JobDispatch, WorkflowDispatch};
14
15use super::auth::{AuthConfig, AuthMiddleware, auth_middleware};
16use super::metrics::{MetricsState, metrics_middleware};
17use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler};
18use super::tracing::TracingState;
19use super::websocket::{WsState, ws_handler};
20use crate::function::FunctionRegistry;
21use crate::observability::ObservabilityState;
22use crate::realtime::{Reactor, ReactorConfig};
23
24#[derive(Debug, Clone)]
26pub struct GatewayConfig {
27 pub port: u16,
29 pub max_connections: usize,
31 pub request_timeout_secs: u64,
33 pub cors_enabled: bool,
35 pub cors_origins: Vec<String>,
37 pub auth: AuthConfig,
39}
40
41impl Default for GatewayConfig {
42 fn default() -> Self {
43 Self {
44 port: 8080,
45 max_connections: 10000,
46 request_timeout_secs: 30,
47 cors_enabled: true,
48 cors_origins: vec!["*".to_string()],
49 auth: AuthConfig::default(),
50 }
51 }
52}
53
54#[derive(Debug, Serialize)]
56pub struct HealthResponse {
57 pub status: String,
58 pub version: String,
59}
60
61#[derive(Debug, Serialize)]
63pub struct ReadinessResponse {
64 pub ready: bool,
65 pub database: bool,
66 pub version: String,
67}
68
69#[derive(Clone)]
71pub struct ReadinessState {
72 db_pool: sqlx::PgPool,
73}
74
75pub struct GatewayServer {
77 config: GatewayConfig,
78 registry: FunctionRegistry,
79 db_pool: sqlx::PgPool,
80 reactor: Arc<Reactor>,
81 observability: Option<ObservabilityState>,
82 job_dispatcher: Option<Arc<dyn JobDispatch>>,
83 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
84}
85
86impl GatewayServer {
87 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
89 let node_id = NodeId::new();
90 let reactor = Arc::new(Reactor::new(
91 node_id,
92 db_pool.clone(),
93 registry.clone(),
94 ReactorConfig::default(),
95 ));
96
97 Self {
98 config,
99 registry,
100 db_pool,
101 reactor,
102 observability: None,
103 job_dispatcher: None,
104 workflow_dispatcher: None,
105 }
106 }
107
108 pub fn with_observability(
110 config: GatewayConfig,
111 registry: FunctionRegistry,
112 db_pool: sqlx::PgPool,
113 observability: ObservabilityState,
114 ) -> Self {
115 let node_id = NodeId::new();
116 let reactor = Arc::new(Reactor::new(
117 node_id,
118 db_pool.clone(),
119 registry.clone(),
120 ReactorConfig::default(),
121 ));
122
123 Self {
124 config,
125 registry,
126 db_pool,
127 reactor,
128 observability: Some(observability),
129 job_dispatcher: None,
130 workflow_dispatcher: None,
131 }
132 }
133
134 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
136 self.job_dispatcher = Some(dispatcher);
137 self
138 }
139
140 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
142 self.workflow_dispatcher = Some(dispatcher);
143 self
144 }
145
146 pub fn reactor(&self) -> Arc<Reactor> {
148 self.reactor.clone()
149 }
150
151 pub fn router(&self) -> Router {
153 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
154 self.registry.clone(),
155 self.db_pool.clone(),
156 self.job_dispatcher.clone(),
157 self.workflow_dispatcher.clone(),
158 ));
159
160 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
161
162 let cors = if self.config.cors_enabled {
164 if self.config.cors_origins.contains(&"*".to_string()) {
165 CorsLayer::new()
166 .allow_origin(Any)
167 .allow_methods(Any)
168 .allow_headers(Any)
169 } else {
170 let origins: Vec<_> = self
171 .config
172 .cors_origins
173 .iter()
174 .filter_map(|o| o.parse().ok())
175 .collect();
176 CorsLayer::new()
177 .allow_origin(origins)
178 .allow_methods(Any)
179 .allow_headers(Any)
180 }
181 } else {
182 CorsLayer::new()
183 };
184
185 let node_id = self.reactor.node_id();
187 let ws_state = Arc::new(WsState::new(
188 self.reactor.clone(),
189 self.db_pool.clone(),
190 node_id,
191 ));
192
193 let readiness_state = Arc::new(ReadinessState {
195 db_pool: self.db_pool.clone(),
196 });
197
198 let mut main_router = Router::new()
200 .route("/health", get(health_handler))
202 .route("/ready", get(readiness_handler).with_state(readiness_state))
204 .route("/rpc", post(rpc_handler))
206 .route("/rpc/{function}", post(rpc_function_handler))
208 .with_state(rpc_handler_state);
210
211 let service_builder = ServiceBuilder::new()
213 .layer(cors.clone())
214 .layer(middleware::from_fn_with_state(
215 auth_middleware_state,
216 auth_middleware,
217 ))
218 .layer(middleware::from_fn(tracing_middleware));
219
220 if let Some(ref observability) = self.observability {
222 let metrics_state = Arc::new(MetricsState::new(observability.clone()));
223 main_router = main_router.layer(middleware::from_fn_with_state(
224 metrics_state,
225 metrics_middleware,
226 ));
227 }
228
229 main_router = main_router.layer(service_builder);
231
232 let ws_router = Router::new()
234 .route("/ws", any(ws_handler).with_state(ws_state))
235 .layer(cors);
236
237 main_router.merge(ws_router)
239 }
240
241 pub fn addr(&self) -> SocketAddr {
243 SocketAddr::from(([0, 0, 0, 0], self.config.port))
244 }
245
246 pub async fn run(self) -> Result<(), std::io::Error> {
248 let addr = self.addr();
249 let router = self.router();
250
251 if let Err(e) = self.reactor.start().await {
253 tracing::error!("Failed to start reactor: {}", e);
254 } else {
255 tracing::info!("Reactor started for real-time updates");
256 }
257
258 tracing::info!("Gateway server listening on {}", addr);
259
260 let listener = tokio::net::TcpListener::bind(addr).await?;
261 axum::serve(listener, router).await
262 }
263}
264
265async fn health_handler() -> Json<HealthResponse> {
267 Json(HealthResponse {
268 status: "healthy".to_string(),
269 version: env!("CARGO_PKG_VERSION").to_string(),
270 })
271}
272
273async fn readiness_handler(
275 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
276) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
277 let db_ok = sqlx::query("SELECT 1")
279 .fetch_one(&state.db_pool)
280 .await
281 .is_ok();
282
283 let ready = db_ok;
284 let status_code = if ready {
285 axum::http::StatusCode::OK
286 } else {
287 axum::http::StatusCode::SERVICE_UNAVAILABLE
288 };
289
290 (
291 status_code,
292 Json(ReadinessResponse {
293 ready,
294 database: db_ok,
295 version: env!("CARGO_PKG_VERSION").to_string(),
296 }),
297 )
298}
299
300async fn tracing_middleware(
302 req: axum::extract::Request,
303 next: axum::middleware::Next,
304) -> axum::response::Response {
305 use axum::http::header::HeaderName;
306
307 let trace_id = req
309 .headers()
310 .get(HeaderName::from_static("x-trace-id"))
311 .and_then(|v| v.to_str().ok())
312 .map(String::from)
313 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
314
315 let tracing_state = TracingState::with_trace_id(trace_id.clone());
316
317 let mut req = req;
318 req.extensions_mut().insert(tracing_state.clone());
319
320 if req
322 .extensions()
323 .get::<forge_core::function::AuthContext>()
324 .is_none()
325 {
326 req.extensions_mut()
327 .insert(forge_core::function::AuthContext::unauthenticated());
328 }
329
330 let mut response = next.run(req).await;
331
332 if let Ok(val) = trace_id.parse() {
334 response.headers_mut().insert("x-trace-id", val);
335 }
336 if let Ok(val) = tracing_state.request_id.parse() {
337 response.headers_mut().insert("x-request-id", val);
338 }
339
340 response
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default configuration has the documented value.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 10000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
    }

    /// The health body serializes to exactly `{status, version}`.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "status": "healthy", "version": "0.1.0" })
        );
    }

    /// The readiness body serializes to exactly `{ready, database, version}`.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(
            json,
            serde_json::json!({ "ready": false, "database": false, "version": "0.1.0" })
        );
    }
}