1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16#[derive(Clone)]
18pub struct RpcHandler {
19 executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24 pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26 let executor = FunctionExecutor::new(Arc::new(registry), db);
27 Self {
28 executor: Arc::new(executor),
29 }
30 }
31
32 pub fn with_dispatch(
34 registry: FunctionRegistry,
35 db: Database,
36 job_dispatcher: Option<Arc<dyn JobDispatch>>,
37 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38 ) -> Self {
39 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40 }
41
42 pub fn with_dispatch_and_issuer(
44 registry: FunctionRegistry,
45 db: Database,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49 ) -> Self {
50 let executor = FunctionExecutor::with_dispatch_and_issuer(
51 Arc::new(registry),
52 db,
53 job_dispatcher,
54 workflow_dispatcher,
55 token_issuer,
56 );
57 Self {
58 executor: Arc::new(executor),
59 }
60 }
61
62 pub async fn handle(
64 &self,
65 request: RpcRequest,
66 auth: AuthContext,
67 metadata: RequestMetadata,
68 ) -> RpcResponse {
69 match self
71 .executor
72 .execute(&request.function, request.args, auth, metadata.clone())
73 .await
74 {
75 Ok(exec_result) => RpcResponse::success(exec_result.result)
76 .with_request_id(metadata.request_id.to_string()),
77 Err(e) => RpcResponse::error(RpcError::from(e))
78 .with_request_id(metadata.request_id.to_string()),
79 }
80 }
81}
82
83fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
85 headers
86 .get("x-forwarded-for")
87 .and_then(|v| v.to_str().ok())
88 .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
89 .filter(|s| !s.is_empty())
90 .or_else(|| {
91 headers
92 .get("x-real-ip")
93 .and_then(|v| v.to_str().ok())
94 .map(|s| s.trim().to_string())
95 .filter(|s| !s.is_empty())
96 })
97}
98
99fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
101 headers
102 .get(USER_AGENT)
103 .and_then(|v| v.to_str().ok())
104 .map(String::from)
105}
106
107pub async fn rpc_handler(
109 State(handler): State<Arc<RpcHandler>>,
110 Extension(auth): Extension<AuthContext>,
111 Extension(tracing): Extension<TracingState>,
112 headers: HeaderMap,
113 Json(request): Json<RpcRequest>,
114) -> RpcResponse {
115 let metadata = RequestMetadata {
116 request_id: uuid::Uuid::parse_str(&tracing.request_id)
117 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
118 trace_id: tracing.trace_id,
119 client_ip: extract_client_ip(&headers),
120 user_agent: extract_user_agent(&headers),
121 timestamp: chrono::Utc::now(),
122 };
123
124 handler.handle(request, auth, metadata).await
125}
126
127#[derive(Debug, serde::Deserialize)]
129pub struct RpcFunctionBody {
130 #[serde(default)]
132 pub args: serde_json::Value,
133}
134
135pub async fn rpc_function_handler(
137 State(handler): State<Arc<RpcHandler>>,
138 Extension(auth): Extension<AuthContext>,
139 Extension(tracing): Extension<TracingState>,
140 headers: HeaderMap,
141 axum::extract::Path(function): axum::extract::Path<String>,
142 Json(body): Json<RpcFunctionBody>,
143) -> RpcResponse {
144 let request = RpcRequest::new(function, body.args);
145
146 let metadata = RequestMetadata {
147 request_id: uuid::Uuid::parse_str(&tracing.request_id)
148 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
149 trace_id: tracing.trace_id,
150 client_ip: extract_client_ip(&headers),
151 user_agent: extract_user_agent(&headers),
152 timestamp: chrono::Utc::now(),
153 };
154
155 handler.handle(request, auth, metadata).await
156}
157
158pub async fn rpc_batch_handler(
160 State(handler): State<Arc<RpcHandler>>,
161 Extension(auth): Extension<AuthContext>,
162 Extension(tracing): Extension<TracingState>,
163 headers: HeaderMap,
164 Json(batch): Json<BatchRpcRequest>,
165) -> BatchRpcResponse {
166 let client_ip = extract_client_ip(&headers);
167 let user_agent = extract_user_agent(&headers);
168 let mut results = Vec::with_capacity(batch.requests.len());
169
170 for request in batch.requests {
171 let metadata = RequestMetadata {
172 request_id: uuid::Uuid::new_v4(),
173 trace_id: tracing.trace_id.clone(),
174 client_ip: client_ip.clone(),
175 user_agent: user_agent.clone(),
176 timestamp: chrono::Utc::now(),
177 };
178
179 let response = handler.handle(request, auth.clone(), metadata).await;
180 results.push(response);
181 }
182
183 BatchRpcResponse { results }
184}
185
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a single-connection pool via `connect_lazy`, which does not open a
    /// connection up front. The URL points at a database that is not expected to exist.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// An `RpcHandler` with an empty function registry over the mock pool.
    fn create_test_handler() -> RpcHandler {
        RpcHandler::new(
            FunctionRegistry::new(),
            Database::from_pool(create_mock_pool()),
        )
    }

    /// Calling a name that is not registered must yield a `NOT_FOUND` error response.
    #[tokio::test]
    async fn test_handle_unknown_function() {
        let response = create_test_handler()
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let error = response.error.as_ref();
        assert!(error.is_some());
        assert_eq!(error.unwrap().code, "NOT_FOUND");
    }
}