1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::{
5 middleware,
6 routing::{any, get, post},
7 Json, Router,
8};
9use serde::Serialize;
10use tower::ServiceBuilder;
11use tower_http::cors::{Any, CorsLayer};
12
13use forge_core::cluster::NodeId;
14use forge_core::function::{JobDispatch, WorkflowDispatch};
15
16use super::auth::{auth_middleware, AuthConfig, AuthMiddleware};
17use super::metrics::{metrics_middleware, MetricsState};
18use super::rpc::{rpc_function_handler, rpc_handler, RpcHandler};
19use super::tracing::TracingState;
20use super::websocket::{ws_handler, WsState};
21use crate::function::FunctionRegistry;
22use crate::observability::ObservabilityState;
23use crate::realtime::{Reactor, ReactorConfig};
24
25#[derive(Debug, Clone)]
27pub struct GatewayConfig {
28 pub port: u16,
30 pub max_connections: usize,
32 pub request_timeout_secs: u64,
34 pub cors_enabled: bool,
36 pub cors_origins: Vec<String>,
38 pub auth: AuthConfig,
40}
41
42impl Default for GatewayConfig {
43 fn default() -> Self {
44 Self {
45 port: 8080,
46 max_connections: 10000,
47 request_timeout_secs: 30,
48 cors_enabled: true,
49 cors_origins: vec!["*".to_string()],
50 auth: AuthConfig::default(),
51 }
52 }
53}
54
55#[derive(Debug, Serialize)]
57pub struct HealthResponse {
58 pub status: String,
59 pub version: String,
60}
61
62#[derive(Debug, Serialize)]
64pub struct ReadinessResponse {
65 pub ready: bool,
66 pub database: bool,
67 pub version: String,
68}
69
70#[derive(Clone)]
72pub struct ReadinessState {
73 db_pool: sqlx::PgPool,
74}
75
76pub struct GatewayServer {
78 config: GatewayConfig,
79 registry: FunctionRegistry,
80 db_pool: sqlx::PgPool,
81 reactor: Arc<Reactor>,
82 observability: Option<ObservabilityState>,
83 job_dispatcher: Option<Arc<dyn JobDispatch>>,
84 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
85}
86
87impl GatewayServer {
88 pub fn new(config: GatewayConfig, registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
90 let node_id = NodeId::new();
91 let reactor = Arc::new(Reactor::new(
92 node_id,
93 db_pool.clone(),
94 registry.clone(),
95 ReactorConfig::default(),
96 ));
97
98 Self {
99 config,
100 registry,
101 db_pool,
102 reactor,
103 observability: None,
104 job_dispatcher: None,
105 workflow_dispatcher: None,
106 }
107 }
108
109 pub fn with_observability(
111 config: GatewayConfig,
112 registry: FunctionRegistry,
113 db_pool: sqlx::PgPool,
114 observability: ObservabilityState,
115 ) -> Self {
116 let node_id = NodeId::new();
117 let reactor = Arc::new(Reactor::new(
118 node_id,
119 db_pool.clone(),
120 registry.clone(),
121 ReactorConfig::default(),
122 ));
123
124 Self {
125 config,
126 registry,
127 db_pool,
128 reactor,
129 observability: Some(observability),
130 job_dispatcher: None,
131 workflow_dispatcher: None,
132 }
133 }
134
135 pub fn with_job_dispatcher(mut self, dispatcher: Arc<dyn JobDispatch>) -> Self {
137 self.job_dispatcher = Some(dispatcher);
138 self
139 }
140
141 pub fn with_workflow_dispatcher(mut self, dispatcher: Arc<dyn WorkflowDispatch>) -> Self {
143 self.workflow_dispatcher = Some(dispatcher);
144 self
145 }
146
147 pub fn reactor(&self) -> Arc<Reactor> {
149 self.reactor.clone()
150 }
151
152 pub fn router(&self) -> Router {
154 let rpc_handler_state = Arc::new(RpcHandler::with_dispatch(
155 self.registry.clone(),
156 self.db_pool.clone(),
157 self.job_dispatcher.clone(),
158 self.workflow_dispatcher.clone(),
159 ));
160
161 let auth_middleware_state = Arc::new(AuthMiddleware::new(self.config.auth.clone()));
162
163 let cors = if self.config.cors_enabled {
165 if self.config.cors_origins.contains(&"*".to_string()) {
166 CorsLayer::new()
167 .allow_origin(Any)
168 .allow_methods(Any)
169 .allow_headers(Any)
170 } else {
171 let origins: Vec<_> = self
172 .config
173 .cors_origins
174 .iter()
175 .filter_map(|o| o.parse().ok())
176 .collect();
177 CorsLayer::new()
178 .allow_origin(origins)
179 .allow_methods(Any)
180 .allow_headers(Any)
181 }
182 } else {
183 CorsLayer::new()
184 };
185
186 let node_id = self.reactor.node_id();
188 let ws_state = Arc::new(WsState::new(
189 self.reactor.clone(),
190 self.db_pool.clone(),
191 node_id,
192 ));
193
194 let readiness_state = Arc::new(ReadinessState {
196 db_pool: self.db_pool.clone(),
197 });
198
199 let mut main_router = Router::new()
201 .route("/health", get(health_handler))
203 .route("/ready", get(readiness_handler).with_state(readiness_state))
205 .route("/rpc", post(rpc_handler))
207 .route("/rpc/{function}", post(rpc_function_handler))
209 .with_state(rpc_handler_state);
211
212 let service_builder = ServiceBuilder::new()
214 .layer(cors.clone())
215 .layer(middleware::from_fn_with_state(
216 auth_middleware_state,
217 auth_middleware,
218 ))
219 .layer(middleware::from_fn(tracing_middleware));
220
221 if let Some(ref observability) = self.observability {
223 let metrics_state = Arc::new(MetricsState::new(observability.clone()));
224 main_router = main_router.layer(middleware::from_fn_with_state(
225 metrics_state,
226 metrics_middleware,
227 ));
228 }
229
230 main_router = main_router.layer(service_builder);
232
233 let ws_router = Router::new()
235 .route("/ws", any(ws_handler).with_state(ws_state))
236 .layer(cors);
237
238 main_router.merge(ws_router)
240 }
241
242 pub fn addr(&self) -> SocketAddr {
244 SocketAddr::from(([0, 0, 0, 0], self.config.port))
245 }
246
247 pub async fn run(self) -> Result<(), std::io::Error> {
249 let addr = self.addr();
250 let router = self.router();
251
252 if let Err(e) = self.reactor.start().await {
254 tracing::error!("Failed to start reactor: {}", e);
255 } else {
256 tracing::info!("Reactor started for real-time updates");
257 }
258
259 tracing::info!("Gateway server listening on {}", addr);
260
261 let listener = tokio::net::TcpListener::bind(addr).await?;
262 axum::serve(listener, router).await
263 }
264}
265
266async fn health_handler() -> Json<HealthResponse> {
268 Json(HealthResponse {
269 status: "healthy".to_string(),
270 version: env!("CARGO_PKG_VERSION").to_string(),
271 })
272}
273
274async fn readiness_handler(
276 axum::extract::State(state): axum::extract::State<Arc<ReadinessState>>,
277) -> (axum::http::StatusCode, Json<ReadinessResponse>) {
278 let db_ok = sqlx::query("SELECT 1")
280 .fetch_one(&state.db_pool)
281 .await
282 .is_ok();
283
284 let ready = db_ok;
285 let status_code = if ready {
286 axum::http::StatusCode::OK
287 } else {
288 axum::http::StatusCode::SERVICE_UNAVAILABLE
289 };
290
291 (
292 status_code,
293 Json(ReadinessResponse {
294 ready,
295 database: db_ok,
296 version: env!("CARGO_PKG_VERSION").to_string(),
297 }),
298 )
299}
300
301async fn tracing_middleware(
303 req: axum::extract::Request,
304 next: axum::middleware::Next,
305) -> axum::response::Response {
306 use axum::http::header::HeaderName;
307
308 let trace_id = req
310 .headers()
311 .get(HeaderName::from_static("x-trace-id"))
312 .and_then(|v| v.to_str().ok())
313 .map(String::from)
314 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
315
316 let tracing_state = TracingState::with_trace_id(trace_id.clone());
317
318 let mut req = req;
319 req.extensions_mut().insert(tracing_state.clone());
320
321 if req
323 .extensions()
324 .get::<forge_core::function::AuthContext>()
325 .is_none()
326 {
327 req.extensions_mut()
328 .insert(forge_core::function::AuthContext::unauthenticated());
329 }
330
331 let mut response = next.run(req).await;
332
333 if let Ok(val) = trace_id.parse() {
335 response.headers_mut().insert("x-trace-id", val);
336 }
337 if let Ok(val) = tracing_state.request_id.parse() {
338 response.headers_mut().insert("x-request-id", val);
339 }
340
341 response
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every plain-data default so accidental changes are caught.
    #[test]
    fn test_gateway_config_default() {
        let config = GatewayConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.max_connections, 10000);
        assert_eq!(config.request_timeout_secs, 30);
        assert!(config.cors_enabled);
        assert_eq!(config.cors_origins, vec!["*".to_string()]);
    }

    /// The health body's wire format, including field order, is part of the API.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).expect("HealthResponse serializes");
        assert_eq!(json, r#"{"status":"healthy","version":"0.1.0"}"#);
    }

    /// The readiness body's wire format, including field order, is part of the API.
    #[test]
    fn test_readiness_response_serialization() {
        let resp = ReadinessResponse {
            ready: false,
            database: false,
            version: "0.1.0".to_string(),
        };
        let json = serde_json::to_string(&resp).expect("ReadinessResponse serializes");
        assert_eq!(json, r#"{"ready":false,"database":false,"version":"0.1.0"}"#);
    }
}