forge_runtime/gateway/
rpc.rs1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::function::{FunctionExecutor, FunctionRegistry};
14
15#[derive(Clone)]
17pub struct RpcHandler {
18 executor: Arc<FunctionExecutor>,
20}
21
22impl RpcHandler {
23 pub fn new(registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
25 let executor = FunctionExecutor::new(Arc::new(registry), db_pool);
26 Self {
27 executor: Arc::new(executor),
28 }
29 }
30
31 pub fn with_dispatch(
33 registry: FunctionRegistry,
34 db_pool: sqlx::PgPool,
35 job_dispatcher: Option<Arc<dyn JobDispatch>>,
36 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
37 ) -> Self {
38 let executor = FunctionExecutor::with_dispatch(
39 Arc::new(registry),
40 db_pool,
41 job_dispatcher,
42 workflow_dispatcher,
43 );
44 Self {
45 executor: Arc::new(executor),
46 }
47 }
48
49 pub async fn handle(
51 &self,
52 request: RpcRequest,
53 auth: AuthContext,
54 metadata: RequestMetadata,
55 ) -> RpcResponse {
56 if !self.executor.has_function(&request.function) {
57 return RpcResponse::error(RpcError::not_found(format!(
58 "Function '{}' not found",
59 request.function
60 )))
61 .with_request_id(metadata.request_id.to_string());
62 }
63
64 match self
65 .executor
66 .execute(&request.function, request.args, auth, metadata.clone())
67 .await
68 {
69 Ok(exec_result) => {
70 if exec_result.success {
71 RpcResponse::success(exec_result.result)
72 .with_request_id(metadata.request_id.to_string())
73 } else {
74 RpcResponse::error(RpcError::internal(
75 exec_result
76 .error
77 .unwrap_or_else(|| "Unknown error".to_string()),
78 ))
79 .with_request_id(metadata.request_id.to_string())
80 }
81 }
82 Err(e) => RpcResponse::error(RpcError::from(e))
83 .with_request_id(metadata.request_id.to_string()),
84 }
85 }
86}
87
88fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
90 headers
91 .get("x-forwarded-for")
92 .and_then(|v| v.to_str().ok())
93 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
94 .filter(|s| !s.is_empty())
95 .or_else(|| {
96 headers
97 .get("x-real-ip")
98 .and_then(|v| v.to_str().ok())
99 .map(|s| s.trim().to_string())
100 .filter(|s| !s.is_empty())
101 })
102}
103
104fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
106 headers
107 .get(USER_AGENT)
108 .and_then(|v| v.to_str().ok())
109 .map(String::from)
110}
111
112pub async fn rpc_handler(
114 State(handler): State<Arc<RpcHandler>>,
115 Extension(auth): Extension<AuthContext>,
116 Extension(tracing): Extension<TracingState>,
117 headers: HeaderMap,
118 Json(request): Json<RpcRequest>,
119) -> RpcResponse {
120 let metadata = RequestMetadata {
121 request_id: uuid::Uuid::parse_str(&tracing.request_id)
122 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
123 trace_id: tracing.trace_id,
124 client_ip: extract_client_ip(&headers),
125 user_agent: extract_user_agent(&headers),
126 timestamp: chrono::Utc::now(),
127 };
128
129 handler.handle(request, auth, metadata).await
130}
131
132#[derive(Debug, serde::Deserialize)]
134pub struct RpcFunctionBody {
135 #[serde(default)]
137 pub args: serde_json::Value,
138}
139
140pub async fn rpc_function_handler(
142 State(handler): State<Arc<RpcHandler>>,
143 Extension(auth): Extension<AuthContext>,
144 Extension(tracing): Extension<TracingState>,
145 headers: HeaderMap,
146 axum::extract::Path(function): axum::extract::Path<String>,
147 Json(body): Json<RpcFunctionBody>,
148) -> RpcResponse {
149 let request = RpcRequest::new(function, body.args);
150
151 let metadata = RequestMetadata {
152 request_id: uuid::Uuid::parse_str(&tracing.request_id)
153 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
154 trace_id: tracing.trace_id,
155 client_ip: extract_client_ip(&headers),
156 user_agent: extract_user_agent(&headers),
157 timestamp: chrono::Utc::now(),
158 };
159
160 handler.handle(request, auth, metadata).await
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Lazily-connecting pool: no connection is attempted until a query runs.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db_pool = create_mock_pool();
        RpcHandler::new(registry, db_pool)
    }

    /// Builds a header map from static name/value pairs.
    fn headers_with(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();

        let response = handler.handle(request, auth, metadata).await;

        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_client_ip_prefers_first_forwarded_entry() {
        let headers = headers_with(&[
            ("x-forwarded-for", " 203.0.113.7 , 10.0.0.1"),
            ("x-real-ip", "192.0.2.1"),
        ]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("203.0.113.7"));
    }

    #[test]
    fn test_client_ip_empty_first_hop_falls_back_to_real_ip() {
        let headers = headers_with(&[
            ("x-forwarded-for", " , 10.0.0.1"),
            ("x-real-ip", " 192.0.2.1 "),
        ]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("192.0.2.1"));
    }

    #[test]
    fn test_client_ip_absent() {
        assert_eq!(extract_client_ip(&HeaderMap::new()), None);
    }

    #[test]
    fn test_user_agent_extraction() {
        let headers = headers_with(&[("user-agent", "forge-test/1.0")]);
        assert_eq!(
            extract_user_agent(&headers).as_deref(),
            Some("forge-test/1.0")
        );
        assert_eq!(extract_user_agent(&HeaderMap::new()), None);
    }

    #[test]
    fn test_user_agent_ignores_non_ascii_value() {
        let mut headers = HeaderMap::new();
        headers.insert(
            USER_AGENT,
            HeaderValue::from_bytes(&[0xFF]).expect("0xFF is a valid obs-text header byte"),
        );
        assert_eq!(extract_user_agent(&headers), None);
    }
}