// forge_runtime/gateway/rpc.rs
use std::sync::Arc;

use axum::{
    Json,
    extract::{Extension, State},
    http::{HeaderMap, header::USER_AGENT},
};
use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};

use super::request::RpcRequest;
use super::response::{RpcError, RpcResponse};
use super::tracing::TracingState;
use crate::db::Database;
use crate::function::{FunctionExecutor, FunctionRegistry};
16#[derive(Clone)]
18pub struct RpcHandler {
19 executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24 pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26 let executor = FunctionExecutor::new(Arc::new(registry), db);
27 Self {
28 executor: Arc::new(executor),
29 }
30 }
31
32 pub fn with_dispatch(
34 registry: FunctionRegistry,
35 db: Database,
36 job_dispatcher: Option<Arc<dyn JobDispatch>>,
37 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38 ) -> Self {
39 let executor = FunctionExecutor::with_dispatch(
40 Arc::new(registry),
41 db,
42 job_dispatcher,
43 workflow_dispatcher,
44 );
45 Self {
46 executor: Arc::new(executor),
47 }
48 }
49
50 pub async fn handle(
52 &self,
53 request: RpcRequest,
54 auth: AuthContext,
55 metadata: RequestMetadata,
56 ) -> RpcResponse {
57 match self
59 .executor
60 .execute(&request.function, request.args, auth, metadata.clone())
61 .await
62 {
63 Ok(exec_result) => RpcResponse::success(exec_result.result)
64 .with_request_id(metadata.request_id.to_string()),
65 Err(e) => RpcResponse::error(RpcError::from(e))
66 .with_request_id(metadata.request_id.to_string()),
67 }
68 }
69}
70
71fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
73 headers
74 .get("x-forwarded-for")
75 .and_then(|v| v.to_str().ok())
76 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
77 .filter(|s| !s.is_empty())
78 .or_else(|| {
79 headers
80 .get("x-real-ip")
81 .and_then(|v| v.to_str().ok())
82 .map(|s| s.trim().to_string())
83 .filter(|s| !s.is_empty())
84 })
85}
86
87fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
89 headers
90 .get(USER_AGENT)
91 .and_then(|v| v.to_str().ok())
92 .map(String::from)
93}
94
95pub async fn rpc_handler(
97 State(handler): State<Arc<RpcHandler>>,
98 Extension(auth): Extension<AuthContext>,
99 Extension(tracing): Extension<TracingState>,
100 headers: HeaderMap,
101 Json(request): Json<RpcRequest>,
102) -> RpcResponse {
103 let metadata = RequestMetadata {
104 request_id: uuid::Uuid::parse_str(&tracing.request_id)
105 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
106 trace_id: tracing.trace_id,
107 client_ip: extract_client_ip(&headers),
108 user_agent: extract_user_agent(&headers),
109 timestamp: chrono::Utc::now(),
110 };
111
112 handler.handle(request, auth, metadata).await
113}
114
115#[derive(Debug, serde::Deserialize)]
117pub struct RpcFunctionBody {
118 #[serde(default)]
120 pub args: serde_json::Value,
121}
122
123pub async fn rpc_function_handler(
125 State(handler): State<Arc<RpcHandler>>,
126 Extension(auth): Extension<AuthContext>,
127 Extension(tracing): Extension<TracingState>,
128 headers: HeaderMap,
129 axum::extract::Path(function): axum::extract::Path<String>,
130 Json(body): Json<RpcFunctionBody>,
131) -> RpcResponse {
132 let request = RpcRequest::new(function, body.args);
133
134 let metadata = RequestMetadata {
135 request_id: uuid::Uuid::parse_str(&tracing.request_id)
136 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
137 trace_id: tracing.trace_id,
138 client_ip: extract_client_ip(&headers),
139 user_agent: extract_user_agent(&headers),
140 timestamp: chrono::Utc::now(),
141 };
142
143 handler.handle(request, auth, metadata).await
144}
145
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Single-connection pool created with `connect_lazy`, so no connection is
    /// opened up front. Tests that never reach the database need no live
    /// Postgres.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler over a freshly created registry (nothing is registered here)
    /// backed by the lazy mock pool.
    fn create_test_handler() -> RpcHandler {
        RpcHandler::new(FunctionRegistry::new(), Database::from_pool(create_mock_pool()))
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let response = create_test_handler()
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let error = response.error.as_ref();
        assert!(error.is_some());
        assert_eq!(error.unwrap().code, "NOT_FOUND");
    }
}