forge_runtime/gateway/
rpc.rs1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::function::{FunctionExecutor, FunctionRegistry};
14
15#[derive(Clone)]
17pub struct RpcHandler {
18 executor: Arc<FunctionExecutor>,
20}
21
22impl RpcHandler {
23 pub fn new(registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
25 let executor = FunctionExecutor::new(Arc::new(registry), db_pool);
26 Self {
27 executor: Arc::new(executor),
28 }
29 }
30
31 pub fn with_dispatch(
33 registry: FunctionRegistry,
34 db_pool: sqlx::PgPool,
35 job_dispatcher: Option<Arc<dyn JobDispatch>>,
36 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
37 ) -> Self {
38 let executor = FunctionExecutor::with_dispatch(
39 Arc::new(registry),
40 db_pool,
41 job_dispatcher,
42 workflow_dispatcher,
43 );
44 Self {
45 executor: Arc::new(executor),
46 }
47 }
48
49 pub async fn handle(
51 &self,
52 request: RpcRequest,
53 auth: AuthContext,
54 metadata: RequestMetadata,
55 ) -> RpcResponse {
56 match self
58 .executor
59 .execute(&request.function, request.args, auth, metadata.clone())
60 .await
61 {
62 Ok(exec_result) => RpcResponse::success(exec_result.result)
63 .with_request_id(metadata.request_id.to_string()),
64 Err(e) => RpcResponse::error(RpcError::from(e))
65 .with_request_id(metadata.request_id.to_string()),
66 }
67 }
68}
69
70fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
72 headers
73 .get("x-forwarded-for")
74 .and_then(|v| v.to_str().ok())
75 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
76 .filter(|s| !s.is_empty())
77 .or_else(|| {
78 headers
79 .get("x-real-ip")
80 .and_then(|v| v.to_str().ok())
81 .map(|s| s.trim().to_string())
82 .filter(|s| !s.is_empty())
83 })
84}
85
86fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
88 headers
89 .get(USER_AGENT)
90 .and_then(|v| v.to_str().ok())
91 .map(String::from)
92}
93
94pub async fn rpc_handler(
96 State(handler): State<Arc<RpcHandler>>,
97 Extension(auth): Extension<AuthContext>,
98 Extension(tracing): Extension<TracingState>,
99 headers: HeaderMap,
100 Json(request): Json<RpcRequest>,
101) -> RpcResponse {
102 let metadata = RequestMetadata {
103 request_id: uuid::Uuid::parse_str(&tracing.request_id)
104 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
105 trace_id: tracing.trace_id,
106 client_ip: extract_client_ip(&headers),
107 user_agent: extract_user_agent(&headers),
108 timestamp: chrono::Utc::now(),
109 };
110
111 handler.handle(request, auth, metadata).await
112}
113
114#[derive(Debug, serde::Deserialize)]
116pub struct RpcFunctionBody {
117 #[serde(default)]
119 pub args: serde_json::Value,
120}
121
122pub async fn rpc_function_handler(
124 State(handler): State<Arc<RpcHandler>>,
125 Extension(auth): Extension<AuthContext>,
126 Extension(tracing): Extension<TracingState>,
127 headers: HeaderMap,
128 axum::extract::Path(function): axum::extract::Path<String>,
129 Json(body): Json<RpcFunctionBody>,
130) -> RpcResponse {
131 let request = RpcRequest::new(function, body.args);
132
133 let metadata = RequestMetadata {
134 request_id: uuid::Uuid::parse_str(&tracing.request_id)
135 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
136 trace_id: tracing.trace_id,
137 client_ip: extract_client_ip(&headers),
138 user_agent: extract_user_agent(&headers),
139 timestamp: chrono::Utc::now(),
140 };
141
142 handler.handle(request, auth, metadata).await
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Pool created with `connect_lazy`, so no connection is attempted at
    /// construction time.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler over an empty `FunctionRegistry`.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db_pool = create_mock_pool();
        RpcHandler::new(registry, db_pool)
    }

    /// Build a `HeaderMap` from static (lowercase) name/value pairs.
    fn headers_from(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, value.parse().expect("valid header value"));
        }
        headers
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();

        let response = handler.handle(request, auth, metadata).await;

        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn client_ip_uses_first_trimmed_forwarded_entry() {
        let headers = headers_from(&[
            ("x-forwarded-for", " 203.0.113.7 , 10.0.0.1"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("203.0.113.7"));
    }

    #[test]
    fn client_ip_falls_back_to_real_ip_when_first_forwarded_entry_is_blank() {
        let headers = headers_from(&[
            ("x-forwarded-for", ", 10.0.0.1"),
            ("x-real-ip", " 198.51.100.2 "),
        ]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("198.51.100.2"));
    }

    #[test]
    fn client_ip_is_none_without_headers() {
        assert_eq!(extract_client_ip(&HeaderMap::new()), None);
        let blank = headers_from(&[("x-real-ip", "   ")]);
        assert_eq!(extract_client_ip(&blank), None);
    }

    #[test]
    fn user_agent_is_copied_when_present() {
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, "forge-test/1.0".parse().expect("valid header value"));
        assert_eq!(extract_user_agent(&headers).as_deref(), Some("forge-test/1.0"));
        assert_eq!(extract_user_agent(&HeaderMap::new()), None);
    }
}