forge_runtime/gateway/
rpc.rs1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16#[derive(Clone)]
18pub struct RpcHandler {
19 executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24 pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26 let executor = FunctionExecutor::new(Arc::new(registry), db);
27 Self {
28 executor: Arc::new(executor),
29 }
30 }
31
32 pub fn with_dispatch(
34 registry: FunctionRegistry,
35 db: Database,
36 job_dispatcher: Option<Arc<dyn JobDispatch>>,
37 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38 ) -> Self {
39 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40 }
41
42 pub fn with_dispatch_and_issuer(
44 registry: FunctionRegistry,
45 db: Database,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49 ) -> Self {
50 let executor = FunctionExecutor::with_dispatch_and_issuer(
51 Arc::new(registry),
52 db,
53 job_dispatcher,
54 workflow_dispatcher,
55 token_issuer,
56 );
57 Self {
58 executor: Arc::new(executor),
59 }
60 }
61
62 pub async fn handle(
64 &self,
65 request: RpcRequest,
66 auth: AuthContext,
67 metadata: RequestMetadata,
68 ) -> RpcResponse {
69 match self
71 .executor
72 .execute(&request.function, request.args, auth, metadata.clone())
73 .await
74 {
75 Ok(exec_result) => RpcResponse::success(exec_result.result)
76 .with_request_id(metadata.request_id.to_string()),
77 Err(e) => RpcResponse::error(RpcError::from(e))
78 .with_request_id(metadata.request_id.to_string()),
79 }
80 }
81}
82
83fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
85 headers
86 .get("x-forwarded-for")
87 .and_then(|v| v.to_str().ok())
88 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
89 .filter(|s| !s.is_empty())
90 .or_else(|| {
91 headers
92 .get("x-real-ip")
93 .and_then(|v| v.to_str().ok())
94 .map(|s| s.trim().to_string())
95 .filter(|s| !s.is_empty())
96 })
97}
98
99fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
101 headers
102 .get(USER_AGENT)
103 .and_then(|v| v.to_str().ok())
104 .map(String::from)
105}
106
107pub async fn rpc_handler(
109 State(handler): State<Arc<RpcHandler>>,
110 Extension(auth): Extension<AuthContext>,
111 Extension(tracing): Extension<TracingState>,
112 headers: HeaderMap,
113 Json(request): Json<RpcRequest>,
114) -> RpcResponse {
115 let metadata = RequestMetadata {
116 request_id: uuid::Uuid::parse_str(&tracing.request_id)
117 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
118 trace_id: tracing.trace_id,
119 client_ip: extract_client_ip(&headers),
120 user_agent: extract_user_agent(&headers),
121 timestamp: chrono::Utc::now(),
122 };
123
124 handler.handle(request, auth, metadata).await
125}
126
127#[derive(Debug, serde::Deserialize)]
129pub struct RpcFunctionBody {
130 #[serde(default)]
132 pub args: serde_json::Value,
133}
134
135pub async fn rpc_function_handler(
137 State(handler): State<Arc<RpcHandler>>,
138 Extension(auth): Extension<AuthContext>,
139 Extension(tracing): Extension<TracingState>,
140 headers: HeaderMap,
141 axum::extract::Path(function): axum::extract::Path<String>,
142 Json(body): Json<RpcFunctionBody>,
143) -> RpcResponse {
144 let request = RpcRequest::new(function, body.args);
145
146 let metadata = RequestMetadata {
147 request_id: uuid::Uuid::parse_str(&tracing.request_id)
148 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
149 trace_id: tracing.trace_id,
150 client_ip: extract_client_ip(&headers),
151 user_agent: extract_user_agent(&headers),
152 timestamp: chrono::Utc::now(),
153 };
154
155 handler.handle(request, auth, metadata).await
156}
157
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Lazily-connecting pool: no connection is attempted at construction
    /// time, so the tests below do not need a database to build a handler.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler backed by an empty function registry and the mock pool.
    fn create_test_handler() -> RpcHandler {
        RpcHandler::new(
            FunctionRegistry::new(),
            Database::from_pool(create_mock_pool()),
        )
    }

    /// Calling a function that was never registered must produce a
    /// `NOT_FOUND` error response rather than a success.
    #[tokio::test]
    async fn test_handle_unknown_function() {
        let response = create_test_handler()
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);

        let error = response.error.as_ref();
        assert!(error.is_some());
        assert_eq!(error.unwrap().code, "NOT_FOUND");
    }
}