Skip to main content

forge_runtime/gateway/
rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::function::{FunctionExecutor, FunctionRegistry};
14
15/// RPC handler for function invocations.
16#[derive(Clone)]
17pub struct RpcHandler {
18    /// Function executor.
19    executor: Arc<FunctionExecutor>,
20}
21
22impl RpcHandler {
23    /// Create a new RPC handler.
24    pub fn new(registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
25        let executor = FunctionExecutor::new(Arc::new(registry), db_pool);
26        Self {
27            executor: Arc::new(executor),
28        }
29    }
30
31    /// Create a new RPC handler with dispatch capabilities.
32    pub fn with_dispatch(
33        registry: FunctionRegistry,
34        db_pool: sqlx::PgPool,
35        job_dispatcher: Option<Arc<dyn JobDispatch>>,
36        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
37    ) -> Self {
38        let executor = FunctionExecutor::with_dispatch(
39            Arc::new(registry),
40            db_pool,
41            job_dispatcher,
42            workflow_dispatcher,
43        );
44        Self {
45            executor: Arc::new(executor),
46        }
47    }
48
49    /// Handle an RPC request.
50    pub async fn handle(
51        &self,
52        request: RpcRequest,
53        auth: AuthContext,
54        metadata: RequestMetadata,
55    ) -> RpcResponse {
56        // Don't check has_function early - let executor try jobs/workflows too
57        match self
58            .executor
59            .execute(&request.function, request.args, auth, metadata.clone())
60            .await
61        {
62            Ok(exec_result) => RpcResponse::success(exec_result.result)
63                .with_request_id(metadata.request_id.to_string()),
64            Err(e) => RpcResponse::error(RpcError::from(e))
65                .with_request_id(metadata.request_id.to_string()),
66        }
67    }
68}
69
70/// Extract client IP from X-Forwarded-For or X-Real-IP headers.
71fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
72    headers
73        .get("x-forwarded-for")
74        .and_then(|v| v.to_str().ok())
75        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
76        .filter(|s| !s.is_empty())
77        .or_else(|| {
78            headers
79                .get("x-real-ip")
80                .and_then(|v| v.to_str().ok())
81                .map(|s| s.trim().to_string())
82                .filter(|s| !s.is_empty())
83        })
84}
85
86/// Extract user agent from headers.
87fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
88    headers
89        .get(USER_AGENT)
90        .and_then(|v| v.to_str().ok())
91        .map(String::from)
92}
93
94/// Axum handler for POST /rpc.
95pub async fn rpc_handler(
96    State(handler): State<Arc<RpcHandler>>,
97    Extension(auth): Extension<AuthContext>,
98    Extension(tracing): Extension<TracingState>,
99    headers: HeaderMap,
100    Json(request): Json<RpcRequest>,
101) -> RpcResponse {
102    let metadata = RequestMetadata {
103        request_id: uuid::Uuid::parse_str(&tracing.request_id)
104            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
105        trace_id: tracing.trace_id,
106        client_ip: extract_client_ip(&headers),
107        user_agent: extract_user_agent(&headers),
108        timestamp: chrono::Utc::now(),
109    };
110
111    handler.handle(request, auth, metadata).await
112}
113
114/// Request body wrapper for REST-style RPC calls.
115#[derive(Debug, serde::Deserialize)]
116pub struct RpcFunctionBody {
117    /// Function arguments.
118    #[serde(default)]
119    pub args: serde_json::Value,
120}
121
122/// Axum handler for POST /rpc/:function (REST-style).
123pub async fn rpc_function_handler(
124    State(handler): State<Arc<RpcHandler>>,
125    Extension(auth): Extension<AuthContext>,
126    Extension(tracing): Extension<TracingState>,
127    headers: HeaderMap,
128    axum::extract::Path(function): axum::extract::Path<String>,
129    Json(body): Json<RpcFunctionBody>,
130) -> RpcResponse {
131    let request = RpcRequest::new(function, body.args);
132
133    let metadata = RequestMetadata {
134        request_id: uuid::Uuid::parse_str(&tracing.request_id)
135            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
136        trace_id: tracing.trace_id,
137        client_ip: extract_client_ip(&headers),
138        user_agent: extract_user_agent(&headers),
139        timestamp: chrono::Utc::now(),
140    };
141
142    handler.handle(request, auth, metadata).await
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{HeaderMap, HeaderValue, header::USER_AGENT};

    /// Builds a pool with `connect_lazy`, which does not connect up front, so
    /// tests that never issue a query need no running Postgres.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler backed by an empty registry and the lazy mock pool.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db_pool = create_mock_pool();
        RpcHandler::new(registry, db_pool)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();

        let response = handler.handle(request, auth, metadata).await;

        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_client_ip_prefers_first_forwarded_entry() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.7 , 10.0.0.1"),
        );
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));

        assert_eq!(extract_client_ip(&headers).as_deref(), Some("203.0.113.7"));
    }

    #[test]
    fn test_client_ip_falls_back_to_real_ip_when_first_entry_blank() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", HeaderValue::from_static(", 10.0.0.1"));
        headers.insert("x-real-ip", HeaderValue::from_static("198.51.100.2"));

        assert_eq!(extract_client_ip(&headers).as_deref(), Some("198.51.100.2"));
    }

    #[test]
    fn test_client_ip_none_without_headers() {
        assert_eq!(extract_client_ip(&HeaderMap::new()), None);
    }

    #[test]
    fn test_user_agent_extraction() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_user_agent(&headers), None);

        headers.insert(USER_AGENT, HeaderValue::from_static("forge-test/1.0"));
        assert_eq!(extract_user_agent(&headers).as_deref(), Some("forge-test/1.0"));
    }
}