// Source: forge_runtime/gateway/rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16/// RPC handler for function invocations.
17#[derive(Clone)]
18pub struct RpcHandler {
19    /// Function executor.
20    executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24    /// Create a new RPC handler.
25    pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26        let executor = FunctionExecutor::new(Arc::new(registry), db);
27        Self {
28            executor: Arc::new(executor),
29        }
30    }
31
32    /// Create a new RPC handler with dispatch capabilities.
33    pub fn with_dispatch(
34        registry: FunctionRegistry,
35        db: Database,
36        job_dispatcher: Option<Arc<dyn JobDispatch>>,
37        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38    ) -> Self {
39        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40    }
41
42    /// Create a new RPC handler with dispatch and token issuer.
43    pub fn with_dispatch_and_issuer(
44        registry: FunctionRegistry,
45        db: Database,
46        job_dispatcher: Option<Arc<dyn JobDispatch>>,
47        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49    ) -> Self {
50        let executor = FunctionExecutor::with_dispatch_and_issuer(
51            Arc::new(registry),
52            db,
53            job_dispatcher,
54            workflow_dispatcher,
55            token_issuer,
56        );
57        Self {
58            executor: Arc::new(executor),
59        }
60    }
61
62    /// Set the token TTL config. Must be called before any requests are handled.
63    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
64        if let Some(executor) = Arc::get_mut(&mut self.executor) {
65            executor.set_token_ttl(ttl);
66        }
67    }
68
69    /// Handle an RPC request.
70    pub async fn handle(
71        &self,
72        request: RpcRequest,
73        auth: AuthContext,
74        metadata: RequestMetadata,
75    ) -> RpcResponse {
76        // Don't check has_function early - let executor try jobs/workflows too
77        match self
78            .executor
79            .execute(&request.function, request.args, auth, metadata.clone())
80            .await
81        {
82            Ok(exec_result) => RpcResponse::success(exec_result.result)
83                .with_request_id(metadata.request_id.to_string()),
84            Err(e) => RpcResponse::error(RpcError::from(e))
85                .with_request_id(metadata.request_id.to_string()),
86        }
87    }
88}
89
90/// Extract client IP from X-Forwarded-For or X-Real-IP headers.
91fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
92    headers
93        .get("x-forwarded-for")
94        .and_then(|v| v.to_str().ok())
95        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
96        .filter(|s| !s.is_empty())
97        .or_else(|| {
98            headers
99                .get("x-real-ip")
100                .and_then(|v| v.to_str().ok())
101                .map(|s| s.trim().to_string())
102                .filter(|s| !s.is_empty())
103        })
104}
105
106/// Extract user agent from headers.
107fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
108    headers
109        .get(USER_AGENT)
110        .and_then(|v| v.to_str().ok())
111        .map(String::from)
112}
113
114/// Build request metadata from tracing state and headers.
115fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
116    RequestMetadata {
117        request_id: uuid::Uuid::parse_str(&tracing.request_id)
118            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
119        trace_id: tracing.trace_id,
120        client_ip: extract_client_ip(headers),
121        user_agent: extract_user_agent(headers),
122        timestamp: chrono::Utc::now(),
123    }
124}
125
126/// Axum handler for POST /rpc.
127pub async fn rpc_handler(
128    State(handler): State<Arc<RpcHandler>>,
129    Extension(auth): Extension<AuthContext>,
130    Extension(tracing): Extension<TracingState>,
131    headers: HeaderMap,
132    Json(request): Json<RpcRequest>,
133) -> RpcResponse {
134    handler
135        .handle(request, auth, build_metadata(tracing, &headers))
136        .await
137}
138
/// Request body wrapper for REST-style RPC calls.
///
/// Used by the `POST /rpc/:function` handler, where the function name comes
/// from the URL path and only the arguments travel in the JSON body.
#[derive(Debug, serde::Deserialize)]
pub struct RpcFunctionBody {
    /// Function arguments.
    ///
    /// When the `args` key is missing, `#[serde(default)]` fills in
    /// `serde_json::Value::default()`, which is JSON `null`.
    #[serde(default)]
    pub args: serde_json::Value,
}
146
147/// Axum handler for POST /rpc/:function (REST-style).
148pub async fn rpc_function_handler(
149    State(handler): State<Arc<RpcHandler>>,
150    Extension(auth): Extension<AuthContext>,
151    Extension(tracing): Extension<TracingState>,
152    headers: HeaderMap,
153    axum::extract::Path(function): axum::extract::Path<String>,
154    Json(body): Json<RpcFunctionBody>,
155) -> RpcResponse {
156    let request = RpcRequest::new(function, body.args);
157    handler
158        .handle(request, auth, build_metadata(tracing, &headers))
159        .await
160}
161
162/// Axum handler for POST /rpc/batch.
163pub async fn rpc_batch_handler(
164    State(handler): State<Arc<RpcHandler>>,
165    Extension(auth): Extension<AuthContext>,
166    Extension(tracing): Extension<TracingState>,
167    headers: HeaderMap,
168    Json(batch): Json<BatchRpcRequest>,
169) -> BatchRpcResponse {
170    let client_ip = extract_client_ip(&headers);
171    let user_agent = extract_user_agent(&headers);
172    let mut results = Vec::with_capacity(batch.requests.len());
173
174    for request in batch.requests {
175        let metadata = RequestMetadata {
176            request_id: uuid::Uuid::new_v4(),
177            trace_id: tracing.trace_id.clone(),
178            client_ip: client_ip.clone(),
179            user_agent: user_agent.clone(),
180            timestamp: chrono::Utc::now(),
181        };
182
183        let response = handler.handle(request, auth.clone(), metadata).await;
184        results.push(response);
185    }
186
187    BatchRpcResponse { results }
188}
189
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a single-connection pool with `connect_lazy`, so no connection
    /// is opened up front. Tests that never touch the database therefore run
    /// without a Postgres server.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler over an empty registry and the lazy mock pool.
    fn create_test_handler() -> RpcHandler {
        RpcHandler::new(FunctionRegistry::new(), Database::from_pool(create_mock_pool()))
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();

        let response = handler
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let Some(error) = response.error.as_ref() else {
            panic!("expected an error response for an unregistered function");
        };
        assert_eq!(error.code, "NOT_FOUND");
    }
}