// forge_runtime/gateway/rpc.rs — RPC gateway handlers.

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{
9    AuthContext, FunctionInfo, JobDispatch, RequestMetadata, WorkflowDispatch,
10};
11
12use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
13use super::response::{RpcError, RpcResponse};
14use super::tracing::TracingState;
15use crate::db::Database;
16use crate::function::{FunctionExecutor, FunctionRegistry};
17
18/// RPC handler for function invocations.
19#[derive(Clone)]
20pub struct RpcHandler {
21    /// Function executor.
22    executor: Arc<FunctionExecutor>,
23}
24
25impl RpcHandler {
26    /// Create a new RPC handler.
27    pub fn new(registry: FunctionRegistry, db: Database) -> Self {
28        let executor = FunctionExecutor::new(Arc::new(registry), db);
29        Self {
30            executor: Arc::new(executor),
31        }
32    }
33
34    /// Create a new RPC handler with dispatch capabilities.
35    pub fn with_dispatch(
36        registry: FunctionRegistry,
37        db: Database,
38        job_dispatcher: Option<Arc<dyn JobDispatch>>,
39        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
40    ) -> Self {
41        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
42    }
43
44    /// Create a new RPC handler with dispatch and token issuer.
45    pub fn with_dispatch_and_issuer(
46        registry: FunctionRegistry,
47        db: Database,
48        job_dispatcher: Option<Arc<dyn JobDispatch>>,
49        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
50        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
51    ) -> Self {
52        let executor = FunctionExecutor::with_dispatch_and_issuer(
53            Arc::new(registry),
54            db,
55            job_dispatcher,
56            workflow_dispatcher,
57            token_issuer,
58        );
59        Self {
60            executor: Arc::new(executor),
61        }
62    }
63
64    /// Set the token TTL config. Must be called before any requests are handled.
65    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
66        if let Some(executor) = Arc::get_mut(&mut self.executor) {
67            executor.set_token_ttl(ttl);
68        }
69    }
70
71    /// Look up function metadata by name.
72    pub fn function_info(&self, name: &str) -> Option<FunctionInfo> {
73        self.executor.function_info(name)
74    }
75
76    /// Set the signals collector for auto-capturing RPC events.
77    pub fn set_signals_collector(
78        &mut self,
79        collector: crate::signals::SignalsCollector,
80        server_secret: String,
81    ) {
82        if let Some(executor) = Arc::get_mut(&mut self.executor) {
83            executor.set_signals_collector(collector, server_secret);
84        }
85    }
86
87    /// Handle an RPC request.
88    pub async fn handle(
89        &self,
90        request: RpcRequest,
91        auth: AuthContext,
92        metadata: RequestMetadata,
93    ) -> RpcResponse {
94        // Don't check has_function early - let executor try jobs/workflows too
95        match self
96            .executor
97            .execute(&request.function, request.args, auth, metadata.clone())
98            .await
99        {
100            Ok(exec_result) => RpcResponse::success(exec_result.result)
101                .with_request_id(metadata.request_id.to_string()),
102            Err(e) => RpcResponse::error(RpcError::from(e))
103                .with_request_id(metadata.request_id.to_string()),
104        }
105    }
106}
107
108use super::extract_client_ip;
109
110/// Extract user agent from headers.
111fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
112    headers
113        .get(USER_AGENT)
114        .and_then(|v| v.to_str().ok())
115        .map(String::from)
116}
117
118/// Build request metadata from tracing state and headers.
119fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
120    RequestMetadata {
121        request_id: uuid::Uuid::parse_str(&tracing.request_id)
122            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
123        trace_id: tracing.trace_id,
124        client_ip: extract_client_ip(headers),
125        user_agent: extract_user_agent(headers),
126        correlation_id: extract_correlation_id(headers),
127        timestamp: chrono::Utc::now(),
128    }
129}
130
131/// Extract the correlation ID from the x-correlation-id header.
132fn extract_correlation_id(headers: &HeaderMap) -> Option<String> {
133    headers
134        .get("x-correlation-id")
135        .and_then(|v| v.to_str().ok())
136        .filter(|v| !v.is_empty() && v.len() <= 64)
137        .map(String::from)
138}
139
140/// Axum handler for POST /rpc.
141pub async fn rpc_handler(
142    State(handler): State<Arc<RpcHandler>>,
143    Extension(auth): Extension<AuthContext>,
144    Extension(tracing): Extension<TracingState>,
145    headers: HeaderMap,
146    Json(request): Json<RpcRequest>,
147) -> RpcResponse {
148    if !is_valid_function_name(&request.function) {
149        return RpcResponse::error(RpcError::validation(
150            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
151        ));
152    }
153    handler
154        .handle(request, auth, build_metadata(tracing, &headers))
155        .await
156}
157
158/// Request body wrapper for REST-style RPC calls.
159#[derive(Debug, serde::Deserialize)]
160pub struct RpcFunctionBody {
161    /// Function arguments.
162    #[serde(default)]
163    pub args: serde_json::Value,
164}
165
/// Validate that a function name contains only safe characters.
///
/// A valid name is 1-256 characters long and consists solely of ASCII letters, ASCII
/// digits, `_`, `.`, `:` and `-`. The check prevents log injection and unexpected
/// behavior from special characters.
///
/// Restricting to ASCII has two effects:
/// - The byte-length limit equals the character count promised in the caller's error
///   message.
/// - Look-alike Unicode letters (e.g. Cyrillic `а` vs Latin `a`) cannot spoof function
///   names in logs.
fn is_valid_function_name(name: &str) -> bool {
    (1..=256).contains(&name.len())
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.' | b':' | b'-'))
}
175
176/// Axum handler for POST /rpc/:function (REST-style).
177pub async fn rpc_function_handler(
178    State(handler): State<Arc<RpcHandler>>,
179    Extension(auth): Extension<AuthContext>,
180    Extension(tracing): Extension<TracingState>,
181    headers: HeaderMap,
182    axum::extract::Path(function): axum::extract::Path<String>,
183    Json(body): Json<RpcFunctionBody>,
184) -> RpcResponse {
185    if !is_valid_function_name(&function) {
186        return RpcResponse::error(RpcError::validation(
187            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
188        ));
189    }
190    let request = RpcRequest::new(function, body.args);
191    handler
192        .handle(request, auth, build_metadata(tracing, &headers))
193        .await
194}
195
196/// Maximum number of requests allowed in a single batch.
197const MAX_BATCH_SIZE: usize = 100;
198
199/// Axum handler for POST /rpc/batch.
200pub async fn rpc_batch_handler(
201    State(handler): State<Arc<RpcHandler>>,
202    Extension(auth): Extension<AuthContext>,
203    Extension(tracing): Extension<TracingState>,
204    headers: HeaderMap,
205    Json(batch): Json<BatchRpcRequest>,
206) -> BatchRpcResponse {
207    // Prevent DoS via unbounded batch size
208    if batch.requests.len() > MAX_BATCH_SIZE {
209        return BatchRpcResponse {
210            results: vec![RpcResponse::error(RpcError::validation(format!(
211                "Batch size {} exceeds maximum of {}",
212                batch.requests.len(),
213                MAX_BATCH_SIZE
214            )))],
215        };
216    }
217
218    let client_ip = extract_client_ip(&headers);
219    let user_agent = extract_user_agent(&headers);
220    let correlation_id = extract_correlation_id(&headers);
221    let mut results = Vec::with_capacity(batch.requests.len());
222
223    for request in batch.requests {
224        // Validate function names in batch requests
225        if !is_valid_function_name(&request.function) {
226            results.push(RpcResponse::error(RpcError::validation(
227                "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
228            )));
229            continue;
230        }
231        let metadata = RequestMetadata {
232            request_id: uuid::Uuid::new_v4(),
233            trace_id: tracing.trace_id.clone(),
234            client_ip: client_ip.clone(),
235            user_agent: user_agent.clone(),
236            correlation_id: correlation_id.clone(),
237            timestamp: chrono::Utc::now(),
238        };
239
240        let response = handler.handle(request, auth.clone(), metadata).await;
241        results.push(response);
242    }
243
244    BatchRpcResponse { results }
245}
246
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Lazily-connecting pool: no connection is attempted until a query runs.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler over an empty registry.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db = Database::from_pool(create_mock_pool());
        RpcHandler::new(registry, db)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();

        let response = handler.handle(request, auth, metadata).await;

        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_function_name_validation() {
        assert!(is_valid_function_name("get_user"));
        assert!(is_valid_function_name("users.list"));
        assert!(is_valid_function_name("jobs:send-email"));
        assert!(is_valid_function_name(&"a".repeat(256)));

        assert!(!is_valid_function_name(""));
        assert!(!is_valid_function_name(&"a".repeat(257)));
        assert!(!is_valid_function_name("get user"));
        assert!(!is_valid_function_name("fn\nname"));
        assert!(!is_valid_function_name("../secrets"));
    }

    #[test]
    fn test_extract_correlation_id() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_correlation_id(&headers), None);

        headers.insert("x-correlation-id", HeaderValue::from_static("abc-123"));
        assert_eq!(extract_correlation_id(&headers).as_deref(), Some("abc-123"));

        headers.insert("x-correlation-id", HeaderValue::from_static(""));
        assert_eq!(extract_correlation_id(&headers), None);

        let too_long = "a".repeat(65);
        headers.insert("x-correlation-id", HeaderValue::from_str(&too_long).unwrap());
        assert_eq!(extract_correlation_id(&headers), None);

        let at_limit = "a".repeat(64);
        headers.insert("x-correlation-id", HeaderValue::from_str(&at_limit).unwrap());
        assert_eq!(extract_correlation_id(&headers), Some(at_limit));
    }

    #[test]
    fn test_extract_user_agent() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_user_agent(&headers), None);

        headers.insert(USER_AGENT, HeaderValue::from_static("forge-client/1.0"));
        assert_eq!(extract_user_agent(&headers).as_deref(), Some("forge-client/1.0"));
    }
}