1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16#[derive(Clone)]
18pub struct RpcHandler {
19 executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24 pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26 let executor = FunctionExecutor::new(Arc::new(registry), db);
27 Self {
28 executor: Arc::new(executor),
29 }
30 }
31
32 pub fn with_dispatch(
34 registry: FunctionRegistry,
35 db: Database,
36 job_dispatcher: Option<Arc<dyn JobDispatch>>,
37 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38 ) -> Self {
39 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40 }
41
42 pub fn with_dispatch_and_issuer(
44 registry: FunctionRegistry,
45 db: Database,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49 ) -> Self {
50 let executor = FunctionExecutor::with_dispatch_and_issuer(
51 Arc::new(registry),
52 db,
53 job_dispatcher,
54 workflow_dispatcher,
55 token_issuer,
56 );
57 Self {
58 executor: Arc::new(executor),
59 }
60 }
61
62 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
64 if let Some(executor) = Arc::get_mut(&mut self.executor) {
65 executor.set_token_ttl(ttl);
66 }
67 }
68
69 pub async fn handle(
71 &self,
72 request: RpcRequest,
73 auth: AuthContext,
74 metadata: RequestMetadata,
75 ) -> RpcResponse {
76 match self
78 .executor
79 .execute(&request.function, request.args, auth, metadata.clone())
80 .await
81 {
82 Ok(exec_result) => RpcResponse::success(exec_result.result)
83 .with_request_id(metadata.request_id.to_string()),
84 Err(e) => RpcResponse::error(RpcError::from(e))
85 .with_request_id(metadata.request_id.to_string()),
86 }
87 }
88}
89
90fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
92 headers
93 .get("x-forwarded-for")
94 .and_then(|v| v.to_str().ok())
95 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
96 .filter(|s| !s.is_empty())
97 .or_else(|| {
98 headers
99 .get("x-real-ip")
100 .and_then(|v| v.to_str().ok())
101 .map(|s| s.trim().to_string())
102 .filter(|s| !s.is_empty())
103 })
104}
105
106fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
108 headers
109 .get(USER_AGENT)
110 .and_then(|v| v.to_str().ok())
111 .map(String::from)
112}
113
114fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
116 RequestMetadata {
117 request_id: uuid::Uuid::parse_str(&tracing.request_id)
118 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
119 trace_id: tracing.trace_id,
120 client_ip: extract_client_ip(headers),
121 user_agent: extract_user_agent(headers),
122 timestamp: chrono::Utc::now(),
123 }
124}
125
126pub async fn rpc_handler(
128 State(handler): State<Arc<RpcHandler>>,
129 Extension(auth): Extension<AuthContext>,
130 Extension(tracing): Extension<TracingState>,
131 headers: HeaderMap,
132 Json(request): Json<RpcRequest>,
133) -> RpcResponse {
134 handler
135 .handle(request, auth, build_metadata(tracing, &headers))
136 .await
137}
138
139#[derive(Debug, serde::Deserialize)]
141pub struct RpcFunctionBody {
142 #[serde(default)]
144 pub args: serde_json::Value,
145}
146
147pub async fn rpc_function_handler(
149 State(handler): State<Arc<RpcHandler>>,
150 Extension(auth): Extension<AuthContext>,
151 Extension(tracing): Extension<TracingState>,
152 headers: HeaderMap,
153 axum::extract::Path(function): axum::extract::Path<String>,
154 Json(body): Json<RpcFunctionBody>,
155) -> RpcResponse {
156 let request = RpcRequest::new(function, body.args);
157 handler
158 .handle(request, auth, build_metadata(tracing, &headers))
159 .await
160}
161
162pub async fn rpc_batch_handler(
164 State(handler): State<Arc<RpcHandler>>,
165 Extension(auth): Extension<AuthContext>,
166 Extension(tracing): Extension<TracingState>,
167 headers: HeaderMap,
168 Json(batch): Json<BatchRpcRequest>,
169) -> BatchRpcResponse {
170 let client_ip = extract_client_ip(&headers);
171 let user_agent = extract_user_agent(&headers);
172 let mut results = Vec::with_capacity(batch.requests.len());
173
174 for request in batch.requests {
175 let metadata = RequestMetadata {
176 request_id: uuid::Uuid::new_v4(),
177 trace_id: tracing.trace_id.clone(),
178 client_ip: client_ip.clone(),
179 user_agent: user_agent.clone(),
180 timestamp: chrono::Utc::now(),
181 };
182
183 let response = handler.handle(request, auth.clone(), metadata).await;
184 results.push(response);
185 }
186
187 BatchRpcResponse { results }
188}
189
// Test code may unwrap, index and panic freely: a panic is how a test reports failure.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Pool pointing at a nonexistent database.
    ///
    /// `connect_lazy` defers connecting, so building the pool succeeds without a
    /// running Postgres. Only tests that never issue a query can use it.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler with an empty registry backed by the mock pool.
    fn create_test_handler() -> RpcHandler {
        let empty_registry = FunctionRegistry::new();
        let mock_db = Database::from_pool(create_mock_pool());
        RpcHandler::new(empty_registry, mock_db)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();

        let response = handler
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let Some(error) = response.error.as_ref() else {
            panic!("an unknown function must produce an error payload");
        };
        assert_eq!(error.code, "NOT_FOUND");
    }
}