1use std::sync::Arc;
2
3use axum::{
4 Json,
5 extract::{Extension, State},
6 http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{
9 AuthContext, FunctionInfo, JobDispatch, RequestMetadata, WorkflowDispatch,
10};
11
12use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
13use super::response::{RpcError, RpcResponse};
14use super::tracing::TracingState;
15use crate::db::Database;
16use crate::function::{FunctionExecutor, FunctionRegistry};
17
/// Front door for RPC calls: every request is delegated to a shared
/// [`FunctionExecutor`].
///
/// Cloning only bumps the `Arc` reference count, so all clones dispatch
/// through the same executor instance.
#[derive(Clone)]
pub struct RpcHandler {
    /// Executor that performs the actual function calls; shared between clones.
    executor: Arc<FunctionExecutor>,
}
24
25impl RpcHandler {
26 pub fn new(registry: FunctionRegistry, db: Database) -> Self {
28 let executor = FunctionExecutor::new(Arc::new(registry), db);
29 Self {
30 executor: Arc::new(executor),
31 }
32 }
33
34 pub fn with_dispatch(
36 registry: FunctionRegistry,
37 db: Database,
38 job_dispatcher: Option<Arc<dyn JobDispatch>>,
39 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
40 ) -> Self {
41 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
42 }
43
44 pub fn with_dispatch_and_issuer(
46 registry: FunctionRegistry,
47 db: Database,
48 job_dispatcher: Option<Arc<dyn JobDispatch>>,
49 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
50 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
51 ) -> Self {
52 let executor = FunctionExecutor::with_dispatch_and_issuer(
53 Arc::new(registry),
54 db,
55 job_dispatcher,
56 workflow_dispatcher,
57 token_issuer,
58 );
59 Self {
60 executor: Arc::new(executor),
61 }
62 }
63
64 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
66 if let Some(executor) = Arc::get_mut(&mut self.executor) {
67 executor.set_token_ttl(ttl);
68 }
69 }
70
71 pub fn function_info(&self, name: &str) -> Option<FunctionInfo> {
73 self.executor.function_info(name)
74 }
75
76 pub fn set_signals_collector(
78 &mut self,
79 collector: crate::signals::SignalsCollector,
80 server_secret: String,
81 ) {
82 if let Some(executor) = Arc::get_mut(&mut self.executor) {
83 executor.set_signals_collector(collector, server_secret);
84 }
85 }
86
87 pub async fn handle(
89 &self,
90 request: RpcRequest,
91 auth: AuthContext,
92 metadata: RequestMetadata,
93 ) -> RpcResponse {
94 match self
96 .executor
97 .execute(&request.function, request.args, auth, metadata.clone())
98 .await
99 {
100 Ok(exec_result) => RpcResponse::success(exec_result.result)
101 .with_request_id(metadata.request_id.to_string()),
102 Err(e) => RpcResponse::error(RpcError::from(e))
103 .with_request_id(metadata.request_id.to_string()),
104 }
105 }
106}
107
108use super::extract_client_ip;
109
110fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
112 headers
113 .get(USER_AGENT)
114 .and_then(|v| v.to_str().ok())
115 .map(String::from)
116}
117
118fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
120 RequestMetadata {
121 request_id: uuid::Uuid::parse_str(&tracing.request_id)
122 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
123 trace_id: tracing.trace_id,
124 client_ip: extract_client_ip(headers),
125 user_agent: extract_user_agent(headers),
126 correlation_id: extract_correlation_id(headers),
127 timestamp: chrono::Utc::now(),
128 }
129}
130
131fn extract_correlation_id(headers: &HeaderMap) -> Option<String> {
133 headers
134 .get("x-correlation-id")
135 .and_then(|v| v.to_str().ok())
136 .filter(|v| !v.is_empty() && v.len() <= 64)
137 .map(String::from)
138}
139
140pub async fn rpc_handler(
142 State(handler): State<Arc<RpcHandler>>,
143 Extension(auth): Extension<AuthContext>,
144 Extension(tracing): Extension<TracingState>,
145 headers: HeaderMap,
146 Json(request): Json<RpcRequest>,
147) -> RpcResponse {
148 if !is_valid_function_name(&request.function) {
149 return RpcResponse::error(RpcError::validation(
150 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
151 ));
152 }
153 handler
154 .handle(request, auth, build_metadata(tracing, &headers))
155 .await
156}
157
/// JSON body accepted by [`rpc_function_handler`], where the function name
/// comes from the URL path rather than the body.
#[derive(Debug, serde::Deserialize)]
pub struct RpcFunctionBody {
    /// Arguments passed to the function. When the field is omitted it falls
    /// back to `serde_json::Value::default()`.
    #[serde(default)]
    pub args: serde_json::Value,
}
165
/// Returns `true` when `name` is an acceptable RPC function name.
///
/// A valid name is 1 to 256 characters long and consists solely of ASCII
/// letters, ASCII digits, `_`, `.`, `:` or `-`.
///
/// The check is deliberately ASCII-only. `char::is_alphanumeric` would also
/// admit every Unicode letter and digit (accented letters, CJK, non-Latin
/// digits, look-alike characters), which does not match the documented
/// contract. Restricting to ASCII also makes the byte length equal the
/// character count, so the 256 limit means what the error message says.
fn is_valid_function_name(name: &str) -> bool {
    (1..=256).contains(&name.len())
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.' | b':' | b'-'))
}
175
176pub async fn rpc_function_handler(
178 State(handler): State<Arc<RpcHandler>>,
179 Extension(auth): Extension<AuthContext>,
180 Extension(tracing): Extension<TracingState>,
181 headers: HeaderMap,
182 axum::extract::Path(function): axum::extract::Path<String>,
183 Json(body): Json<RpcFunctionBody>,
184) -> RpcResponse {
185 if !is_valid_function_name(&function) {
186 return RpcResponse::error(RpcError::validation(
187 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
188 ));
189 }
190 let request = RpcRequest::new(function, body.args);
191 handler
192 .handle(request, auth, build_metadata(tracing, &headers))
193 .await
194}
195
196const MAX_BATCH_SIZE: usize = 100;
198
199pub async fn rpc_batch_handler(
201 State(handler): State<Arc<RpcHandler>>,
202 Extension(auth): Extension<AuthContext>,
203 Extension(tracing): Extension<TracingState>,
204 headers: HeaderMap,
205 Json(batch): Json<BatchRpcRequest>,
206) -> BatchRpcResponse {
207 if batch.requests.len() > MAX_BATCH_SIZE {
209 return BatchRpcResponse {
210 results: vec![RpcResponse::error(RpcError::validation(format!(
211 "Batch size {} exceeds maximum of {}",
212 batch.requests.len(),
213 MAX_BATCH_SIZE
214 )))],
215 };
216 }
217
218 let client_ip = extract_client_ip(&headers);
219 let user_agent = extract_user_agent(&headers);
220 let correlation_id = extract_correlation_id(&headers);
221 let mut results = Vec::with_capacity(batch.requests.len());
222
223 for request in batch.requests {
224 if !is_valid_function_name(&request.function) {
226 results.push(RpcResponse::error(RpcError::validation(
227 "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
228 )));
229 continue;
230 }
231 let metadata = RequestMetadata {
232 request_id: uuid::Uuid::new_v4(),
233 trace_id: tracing.trace_id.clone(),
234 client_ip: client_ip.clone(),
235 user_agent: user_agent.clone(),
236 correlation_id: correlation_id.clone(),
237 timestamp: chrono::Utc::now(),
238 };
239
240 let response = handler.handle(request, auth.clone(), metadata).await;
241 results.push(response);
242 }
243
244 BatchRpcResponse { results }
245}
246
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a handler with an empty registry on top of a lazily connected
    /// pool; `connect_lazy` does not open a connection up front, so no
    /// database is needed to construct it.
    fn create_test_handler() -> RpcHandler {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        RpcHandler::new(FunctionRegistry::new(), Database::from_pool(pool))
    }

    /// Calling a function that is not in the registry yields a `NOT_FOUND`
    /// error response.
    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();

        let response = handler
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let error = response
            .error
            .as_ref()
            .expect("unknown function should produce an error payload");
        assert_eq!(error.code, "NOT_FOUND");
    }
}