Skip to main content

forge_runtime/gateway/
rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16/// RPC handler for function invocations.
17#[derive(Clone)]
18pub struct RpcHandler {
19    /// Function executor.
20    executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24    /// Create a new RPC handler.
25    pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26        let executor = FunctionExecutor::new(Arc::new(registry), db);
27        Self {
28            executor: Arc::new(executor),
29        }
30    }
31
32    /// Create a new RPC handler with dispatch capabilities.
33    pub fn with_dispatch(
34        registry: FunctionRegistry,
35        db: Database,
36        job_dispatcher: Option<Arc<dyn JobDispatch>>,
37        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38    ) -> Self {
39        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40    }
41
42    /// Create a new RPC handler with dispatch and token issuer.
43    pub fn with_dispatch_and_issuer(
44        registry: FunctionRegistry,
45        db: Database,
46        job_dispatcher: Option<Arc<dyn JobDispatch>>,
47        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49    ) -> Self {
50        let executor = FunctionExecutor::with_dispatch_and_issuer(
51            Arc::new(registry),
52            db,
53            job_dispatcher,
54            workflow_dispatcher,
55            token_issuer,
56        );
57        Self {
58            executor: Arc::new(executor),
59        }
60    }
61
62    /// Set the token TTL config. Must be called before any requests are handled.
63    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
64        if let Some(executor) = Arc::get_mut(&mut self.executor) {
65            executor.set_token_ttl(ttl);
66        }
67    }
68
69    /// Set the signals collector for auto-capturing RPC events.
70    pub fn set_signals_collector(
71        &mut self,
72        collector: crate::signals::SignalsCollector,
73        server_secret: String,
74    ) {
75        if let Some(executor) = Arc::get_mut(&mut self.executor) {
76            executor.set_signals_collector(collector, server_secret);
77        }
78    }
79
80    /// Handle an RPC request.
81    pub async fn handle(
82        &self,
83        request: RpcRequest,
84        auth: AuthContext,
85        metadata: RequestMetadata,
86    ) -> RpcResponse {
87        // Don't check has_function early - let executor try jobs/workflows too
88        match self
89            .executor
90            .execute(&request.function, request.args, auth, metadata.clone())
91            .await
92        {
93            Ok(exec_result) => RpcResponse::success(exec_result.result)
94                .with_request_id(metadata.request_id.to_string()),
95            Err(e) => RpcResponse::error(RpcError::from(e))
96                .with_request_id(metadata.request_id.to_string()),
97        }
98    }
99}
100
101use super::extract_client_ip;
102
103/// Extract user agent from headers.
104fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
105    headers
106        .get(USER_AGENT)
107        .and_then(|v| v.to_str().ok())
108        .map(String::from)
109}
110
111/// Build request metadata from tracing state and headers.
112fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
113    RequestMetadata {
114        request_id: uuid::Uuid::parse_str(&tracing.request_id)
115            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
116        trace_id: tracing.trace_id,
117        client_ip: extract_client_ip(headers),
118        user_agent: extract_user_agent(headers),
119        correlation_id: extract_correlation_id(headers),
120        timestamp: chrono::Utc::now(),
121    }
122}
123
124/// Extract the correlation ID from the x-correlation-id header.
125fn extract_correlation_id(headers: &HeaderMap) -> Option<String> {
126    headers
127        .get("x-correlation-id")
128        .and_then(|v| v.to_str().ok())
129        .filter(|v| !v.is_empty() && v.len() <= 64)
130        .map(String::from)
131}
132
133/// Axum handler for POST /rpc.
134pub async fn rpc_handler(
135    State(handler): State<Arc<RpcHandler>>,
136    Extension(auth): Extension<AuthContext>,
137    Extension(tracing): Extension<TracingState>,
138    headers: HeaderMap,
139    Json(request): Json<RpcRequest>,
140) -> RpcResponse {
141    if !is_valid_function_name(&request.function) {
142        return RpcResponse::error(RpcError::validation(
143            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
144        ));
145    }
146    handler
147        .handle(request, auth, build_metadata(tracing, &headers))
148        .await
149}
150
151/// Request body wrapper for REST-style RPC calls.
152#[derive(Debug, serde::Deserialize)]
153pub struct RpcFunctionBody {
154    /// Function arguments.
155    #[serde(default)]
156    pub args: serde_json::Value,
157}
158
/// Validate that a function name contains only safe characters.
///
/// Accepts 1-256 ASCII alphanumeric characters, underscores (`_`), dots (`.`),
/// colons (`:`), or hyphens (`-`). This prevents log injection and unexpected
/// behavior from special characters.
///
/// The check is deliberately ASCII-only. `char::is_alphanumeric` would also admit
/// Unicode letters and digits, such as homoglyphs like Cyrillic `а` (U+0430) next to
/// Latin `a`. With ASCII only, the byte-length limit is also exactly a character limit.
fn is_valid_function_name(name: &str) -> bool {
    (1..=256).contains(&name.len())
        && name
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'.' | b':' | b'-'))
}
168
169/// Axum handler for POST /rpc/:function (REST-style).
170pub async fn rpc_function_handler(
171    State(handler): State<Arc<RpcHandler>>,
172    Extension(auth): Extension<AuthContext>,
173    Extension(tracing): Extension<TracingState>,
174    headers: HeaderMap,
175    axum::extract::Path(function): axum::extract::Path<String>,
176    Json(body): Json<RpcFunctionBody>,
177) -> RpcResponse {
178    if !is_valid_function_name(&function) {
179        return RpcResponse::error(RpcError::validation(
180            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
181        ));
182    }
183    let request = RpcRequest::new(function, body.args);
184    handler
185        .handle(request, auth, build_metadata(tracing, &headers))
186        .await
187}
188
189/// Maximum number of requests allowed in a single batch.
190const MAX_BATCH_SIZE: usize = 100;
191
192/// Axum handler for POST /rpc/batch.
193pub async fn rpc_batch_handler(
194    State(handler): State<Arc<RpcHandler>>,
195    Extension(auth): Extension<AuthContext>,
196    Extension(tracing): Extension<TracingState>,
197    headers: HeaderMap,
198    Json(batch): Json<BatchRpcRequest>,
199) -> BatchRpcResponse {
200    // Prevent DoS via unbounded batch size
201    if batch.requests.len() > MAX_BATCH_SIZE {
202        return BatchRpcResponse {
203            results: vec![RpcResponse::error(RpcError::validation(format!(
204                "Batch size {} exceeds maximum of {}",
205                batch.requests.len(),
206                MAX_BATCH_SIZE
207            )))],
208        };
209    }
210
211    let client_ip = extract_client_ip(&headers);
212    let user_agent = extract_user_agent(&headers);
213    let correlation_id = extract_correlation_id(&headers);
214    let mut results = Vec::with_capacity(batch.requests.len());
215
216    for request in batch.requests {
217        // Validate function names in batch requests
218        if !is_valid_function_name(&request.function) {
219            results.push(RpcResponse::error(RpcError::validation(
220                "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
221            )));
222            continue;
223        }
224        let metadata = RequestMetadata {
225            request_id: uuid::Uuid::new_v4(),
226            trace_id: tracing.trace_id.clone(),
227            client_ip: client_ip.clone(),
228            user_agent: user_agent.clone(),
229            correlation_id: correlation_id.clone(),
230            timestamp: chrono::Utc::now(),
231        };
232
233        let response = handler.handle(request, auth.clone(), metadata).await;
234        results.push(response);
235    }
236
237    BatchRpcResponse { results }
238}
239
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue};

    /// Lazily-connecting pool: never touches the network unless a query runs.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler with an empty registry, backed by the mock pool.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db = Database::from_pool(create_mock_pool());
        RpcHandler::new(registry, db)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();

        let response = handler.handle(request, auth, metadata).await;

        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_function_name_validation() {
        assert!(is_valid_function_name("get_user"));
        assert!(is_valid_function_name("users.get"));
        assert!(is_valid_function_name("jobs:send-email"));
        assert!(is_valid_function_name(&"a".repeat(256)));

        assert!(!is_valid_function_name(""));
        assert!(!is_valid_function_name(&"a".repeat(257)));
        assert!(!is_valid_function_name("get user"));
        assert!(!is_valid_function_name("get_user\nINJECTED"));
        assert!(!is_valid_function_name("../etc/passwd"));
    }

    #[test]
    fn test_extract_correlation_id() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_correlation_id(&headers), None);

        headers.insert("x-correlation-id", HeaderValue::from_static("corr-123"));
        assert_eq!(extract_correlation_id(&headers).as_deref(), Some("corr-123"));

        headers.insert("x-correlation-id", HeaderValue::from_static(""));
        assert_eq!(extract_correlation_id(&headers), None);

        let at_limit = "a".repeat(64);
        headers.insert("x-correlation-id", HeaderValue::from_str(&at_limit).unwrap());
        assert_eq!(extract_correlation_id(&headers).as_deref(), Some(at_limit.as_str()));

        let too_long = "a".repeat(65);
        headers.insert("x-correlation-id", HeaderValue::from_str(&too_long).unwrap());
        assert_eq!(extract_correlation_id(&headers), None);
    }

    #[test]
    fn test_extract_user_agent() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_user_agent(&headers), None);

        headers.insert(USER_AGENT, HeaderValue::from_static("forge-client/1.0"));
        assert_eq!(extract_user_agent(&headers).as_deref(), Some("forge-client/1.0"));
    }
}