Skip to main content

forge_runtime/gateway/
rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16/// RPC handler for function invocations.
17#[derive(Clone)]
18pub struct RpcHandler {
19    /// Function executor.
20    executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24    /// Create a new RPC handler.
25    pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26        let executor = FunctionExecutor::new(Arc::new(registry), db);
27        Self {
28            executor: Arc::new(executor),
29        }
30    }
31
32    /// Create a new RPC handler with dispatch capabilities.
33    pub fn with_dispatch(
34        registry: FunctionRegistry,
35        db: Database,
36        job_dispatcher: Option<Arc<dyn JobDispatch>>,
37        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38    ) -> Self {
39        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
40    }
41
42    /// Create a new RPC handler with dispatch and token issuer.
43    pub fn with_dispatch_and_issuer(
44        registry: FunctionRegistry,
45        db: Database,
46        job_dispatcher: Option<Arc<dyn JobDispatch>>,
47        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
49    ) -> Self {
50        let executor = FunctionExecutor::with_dispatch_and_issuer(
51            Arc::new(registry),
52            db,
53            job_dispatcher,
54            workflow_dispatcher,
55            token_issuer,
56        );
57        Self {
58            executor: Arc::new(executor),
59        }
60    }
61
62    /// Handle an RPC request.
63    pub async fn handle(
64        &self,
65        request: RpcRequest,
66        auth: AuthContext,
67        metadata: RequestMetadata,
68    ) -> RpcResponse {
69        // Don't check has_function early - let executor try jobs/workflows too
70        match self
71            .executor
72            .execute(&request.function, request.args, auth, metadata.clone())
73            .await
74        {
75            Ok(exec_result) => RpcResponse::success(exec_result.result)
76                .with_request_id(metadata.request_id.to_string()),
77            Err(e) => RpcResponse::error(RpcError::from(e))
78                .with_request_id(metadata.request_id.to_string()),
79        }
80    }
81}
82
83/// Extract client IP from X-Forwarded-For or X-Real-IP headers.
84fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
85    headers
86        .get("x-forwarded-for")
87        .and_then(|v| v.to_str().ok())
88        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
89        .filter(|s| !s.is_empty())
90        .or_else(|| {
91            headers
92                .get("x-real-ip")
93                .and_then(|v| v.to_str().ok())
94                .map(|s| s.trim().to_string())
95                .filter(|s| !s.is_empty())
96        })
97}
98
99/// Extract user agent from headers.
100fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
101    headers
102        .get(USER_AGENT)
103        .and_then(|v| v.to_str().ok())
104        .map(String::from)
105}
106
107/// Axum handler for POST /rpc.
108pub async fn rpc_handler(
109    State(handler): State<Arc<RpcHandler>>,
110    Extension(auth): Extension<AuthContext>,
111    Extension(tracing): Extension<TracingState>,
112    headers: HeaderMap,
113    Json(request): Json<RpcRequest>,
114) -> RpcResponse {
115    let metadata = RequestMetadata {
116        request_id: uuid::Uuid::parse_str(&tracing.request_id)
117            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
118        trace_id: tracing.trace_id,
119        client_ip: extract_client_ip(&headers),
120        user_agent: extract_user_agent(&headers),
121        timestamp: chrono::Utc::now(),
122    };
123
124    handler.handle(request, auth, metadata).await
125}
126
127/// Request body wrapper for REST-style RPC calls.
128#[derive(Debug, serde::Deserialize)]
129pub struct RpcFunctionBody {
130    /// Function arguments.
131    #[serde(default)]
132    pub args: serde_json::Value,
133}
134
135/// Axum handler for POST /rpc/:function (REST-style).
136pub async fn rpc_function_handler(
137    State(handler): State<Arc<RpcHandler>>,
138    Extension(auth): Extension<AuthContext>,
139    Extension(tracing): Extension<TracingState>,
140    headers: HeaderMap,
141    axum::extract::Path(function): axum::extract::Path<String>,
142    Json(body): Json<RpcFunctionBody>,
143) -> RpcResponse {
144    let request = RpcRequest::new(function, body.args);
145
146    let metadata = RequestMetadata {
147        request_id: uuid::Uuid::parse_str(&tracing.request_id)
148            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
149        trace_id: tracing.trace_id,
150        client_ip: extract_client_ip(&headers),
151        user_agent: extract_user_agent(&headers),
152        timestamp: chrono::Utc::now(),
153    };
154
155    handler.handle(request, auth, metadata).await
156}
157
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Lazily-connecting pool; no connection is attempted unless a query runs.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler over an empty registry.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db = Database::from_pool(create_mock_pool());
        RpcHandler::new(registry, db)
    }

    /// Build a `HeaderMap` from `(name, value)` pairs.
    fn header_map(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();

        let response = handler.handle(request, auth, metadata).await;

        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn client_ip_prefers_first_forwarded_for_entry() {
        let headers = header_map(&[
            ("x-forwarded-for", "203.0.113.7, 10.0.0.1"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("203.0.113.7"));
    }

    #[test]
    fn client_ip_falls_back_to_real_ip() {
        let headers = header_map(&[("x-real-ip", "198.51.100.2")]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("198.51.100.2"));
    }

    #[test]
    fn client_ip_falls_back_when_first_forwarded_entry_is_blank() {
        let headers = header_map(&[
            ("x-forwarded-for", ", 10.0.0.1"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        assert_eq!(extract_client_ip(&headers).as_deref(), Some("198.51.100.2"));
    }

    #[test]
    fn client_ip_none_without_headers() {
        assert_eq!(extract_client_ip(&HeaderMap::new()), None);
    }

    #[test]
    fn user_agent_is_extracted_verbatim() {
        let headers = header_map(&[("user-agent", "forge-test/1.0")]);
        assert_eq!(extract_user_agent(&headers).as_deref(), Some("forge-test/1.0"));
        assert_eq!(extract_user_agent(&HeaderMap::new()), None);
    }
}