1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16#[derive(Clone)]
18pub struct RpcHandler {
19 executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24 pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26 let executor = FunctionExecutor::new(Arc::new(registry), db);
27 Self {
28 executor: Arc::new(executor),
29 }
30 }
31
32 pub fn with_dispatch(
34 registry: FunctionRegistry,
35 db: Database,
36 job_dispatcher: Option<Arc<dyn JobDispatch>>,
37 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38 ) -> Self {
39 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40 }
41
42 pub fn with_dispatch_and_issuer(
44 registry: FunctionRegistry,
45 db: Database,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49 ) -> Self {
50 let executor = FunctionExecutor::with_dispatch_and_issuer(
51 Arc::new(registry),
52 db,
53 job_dispatcher,
54 workflow_dispatcher,
55 token_issuer,
56 );
57 Self {
58 executor: Arc::new(executor),
59 }
60 }
61
62 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
64 if let Some(executor) = Arc::get_mut(&mut self.executor) {
65 executor.set_token_ttl(ttl);
66 }
67 }
68
69 pub fn set_signals_collector(
71 &mut self,
72 collector: crate::signals::SignalsCollector,
73 server_secret: String,
74 ) {
75 if let Some(executor) = Arc::get_mut(&mut self.executor) {
76 executor.set_signals_collector(collector, server_secret);
77 }
78 }
79
80 pub async fn handle(
82 &self,
83 request: RpcRequest,
84 auth: AuthContext,
85 metadata: RequestMetadata,
86 ) -> RpcResponse {
87 match self
89 .executor
90 .execute(&request.function, request.args, auth, metadata.clone())
91 .await
92 {
93 Ok(exec_result) => RpcResponse::success(exec_result.result)
94 .with_request_id(metadata.request_id.to_string()),
95 Err(e) => RpcResponse::error(RpcError::from(e))
96 .with_request_id(metadata.request_id.to_string()),
97 }
98 }
99}
100
101use super::extract_client_ip;
102
103fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
105 headers
106 .get(USER_AGENT)
107 .and_then(|v| v.to_str().ok())
108 .map(String::from)
109}
110
111fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
113 RequestMetadata {
114 request_id: uuid::Uuid::parse_str(&tracing.request_id)
115 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
116 trace_id: tracing.trace_id,
117 client_ip: extract_client_ip(headers),
118 user_agent: extract_user_agent(headers),
119 correlation_id: extract_correlation_id(headers),
120 timestamp: chrono::Utc::now(),
121 }
122}
123
124fn extract_correlation_id(headers: &HeaderMap) -> Option<String> {
126 headers
127 .get("x-correlation-id")
128 .and_then(|v| v.to_str().ok())
129 .filter(|v| !v.is_empty() && v.len() <= 64)
130 .map(String::from)
131}
132
133pub async fn rpc_handler(
135 State(handler): State<Arc<RpcHandler>>,
136 Extension(auth): Extension<AuthContext>,
137 Extension(tracing): Extension<TracingState>,
138 headers: HeaderMap,
139 Json(request): Json<RpcRequest>,
140) -> RpcResponse {
141 if !is_valid_function_name(&request.function) {
142 return RpcResponse::error(RpcError::validation(
143 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
144 ));
145 }
146 handler
147 .handle(request, auth, build_metadata(tracing, &headers))
148 .await
149}
150
151#[derive(Debug, serde::Deserialize)]
153pub struct RpcFunctionBody {
154 #[serde(default)]
156 pub args: serde_json::Value,
157}
158
/// Returns `true` if `name` is an acceptable RPC function name.
///
/// A valid name is 1-256 characters long and uses only ASCII letters, ASCII digits,
/// `_`, `.`, `:` and `-`.
///
/// The ASCII restriction does two things:
/// - It makes the byte-length check equal to the character count promised in the
///   validation error message.
/// - It rejects Unicode look-alikes (e.g. Cyrillic letters, full-width digits) that
///   `char::is_alphanumeric` would otherwise accept.
fn is_valid_function_name(name: &str) -> bool {
    const MAX_LEN: usize = 256;

    !name.is_empty()
        && name.len() <= MAX_LEN
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.' | b':' | b'-'))
}
168
169pub async fn rpc_function_handler(
171 State(handler): State<Arc<RpcHandler>>,
172 Extension(auth): Extension<AuthContext>,
173 Extension(tracing): Extension<TracingState>,
174 headers: HeaderMap,
175 axum::extract::Path(function): axum::extract::Path<String>,
176 Json(body): Json<RpcFunctionBody>,
177) -> RpcResponse {
178 if !is_valid_function_name(&function) {
179 return RpcResponse::error(RpcError::validation(
180 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
181 ));
182 }
183 let request = RpcRequest::new(function, body.args);
184 handler
185 .handle(request, auth, build_metadata(tracing, &headers))
186 .await
187}
188
189const MAX_BATCH_SIZE: usize = 100;
191
192pub async fn rpc_batch_handler(
194 State(handler): State<Arc<RpcHandler>>,
195 Extension(auth): Extension<AuthContext>,
196 Extension(tracing): Extension<TracingState>,
197 headers: HeaderMap,
198 Json(batch): Json<BatchRpcRequest>,
199) -> BatchRpcResponse {
200 if batch.requests.len() > MAX_BATCH_SIZE {
202 return BatchRpcResponse {
203 results: vec![RpcResponse::error(RpcError::validation(format!(
204 "Batch size {} exceeds maximum of {}",
205 batch.requests.len(),
206 MAX_BATCH_SIZE
207 )))],
208 };
209 }
210
211 let client_ip = extract_client_ip(&headers);
212 let user_agent = extract_user_agent(&headers);
213 let correlation_id = extract_correlation_id(&headers);
214 let mut results = Vec::with_capacity(batch.requests.len());
215
216 for request in batch.requests {
217 if !is_valid_function_name(&request.function) {
219 results.push(RpcResponse::error(RpcError::validation(
220 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
221 )));
222 continue;
223 }
224 let metadata = RequestMetadata {
225 request_id: uuid::Uuid::new_v4(),
226 trace_id: tracing.trace_id.clone(),
227 client_ip: client_ip.clone(),
228 user_agent: user_agent.clone(),
229 correlation_id: correlation_id.clone(),
230 timestamp: chrono::Utc::now(),
231 };
232
233 let response = handler.handle(request, auth.clone(), metadata).await;
234 results.push(response);
235 }
236
237 BatchRpcResponse { results }
238}
239
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a single-connection pool via `connect_lazy`, which does not open a
    /// connection up front, so the handler can be constructed without a live database.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Builds an [`RpcHandler`] over an empty function registry.
    fn create_test_handler() -> RpcHandler {
        let db = Database::from_pool(create_mock_pool());
        RpcHandler::new(FunctionRegistry::new(), db)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();

        let response = handler
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let Some(error) = response.error.as_ref() else {
            panic!("expected an error payload for an unknown function");
        };
        assert_eq!(error.code, "NOT_FOUND");
    }
}