Skip to main content

forge_runtime/gateway/
rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::db::Database;
14use crate::function::{FunctionExecutor, FunctionRegistry};
15
16/// RPC handler for function invocations.
17#[derive(Clone)]
18pub struct RpcHandler {
19    /// Function executor.
20    executor: Arc<FunctionExecutor>,
21}
22
23impl RpcHandler {
24    /// Create a new RPC handler.
25    pub fn new(registry: FunctionRegistry, db: Database) -> Self {
26        let executor = FunctionExecutor::new(Arc::new(registry), db);
27        Self {
28            executor: Arc::new(executor),
29        }
30    }
31
32    /// Create a new RPC handler with dispatch capabilities.
33    pub fn with_dispatch(
34        registry: FunctionRegistry,
35        db: Database,
36        job_dispatcher: Option<Arc<dyn JobDispatch>>,
37        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
38    ) -> Self {
39        let executor = FunctionExecutor::with_dispatch(
40            Arc::new(registry),
41            db,
42            job_dispatcher,
43            workflow_dispatcher,
44        );
45        Self {
46            executor: Arc::new(executor),
47        }
48    }
49
50    /// Handle an RPC request.
51    pub async fn handle(
52        &self,
53        request: RpcRequest,
54        auth: AuthContext,
55        metadata: RequestMetadata,
56    ) -> RpcResponse {
57        // Don't check has_function early - let executor try jobs/workflows too
58        match self
59            .executor
60            .execute(&request.function, request.args, auth, metadata.clone())
61            .await
62        {
63            Ok(exec_result) => RpcResponse::success(exec_result.result)
64                .with_request_id(metadata.request_id.to_string()),
65            Err(e) => RpcResponse::error(RpcError::from(e))
66                .with_request_id(metadata.request_id.to_string()),
67        }
68    }
69}
70
71/// Extract client IP from X-Forwarded-For or X-Real-IP headers.
72fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
73    headers
74        .get("x-forwarded-for")
75        .and_then(|v| v.to_str().ok())
76        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
77        .filter(|s| !s.is_empty())
78        .or_else(|| {
79            headers
80                .get("x-real-ip")
81                .and_then(|v| v.to_str().ok())
82                .map(|s| s.trim().to_string())
83                .filter(|s| !s.is_empty())
84        })
85}
86
87/// Extract user agent from headers.
88fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
89    headers
90        .get(USER_AGENT)
91        .and_then(|v| v.to_str().ok())
92        .map(String::from)
93}
94
95/// Axum handler for POST /rpc.
96pub async fn rpc_handler(
97    State(handler): State<Arc<RpcHandler>>,
98    Extension(auth): Extension<AuthContext>,
99    Extension(tracing): Extension<TracingState>,
100    headers: HeaderMap,
101    Json(request): Json<RpcRequest>,
102) -> RpcResponse {
103    let metadata = RequestMetadata {
104        request_id: uuid::Uuid::parse_str(&tracing.request_id)
105            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
106        trace_id: tracing.trace_id,
107        client_ip: extract_client_ip(&headers),
108        user_agent: extract_user_agent(&headers),
109        timestamp: chrono::Utc::now(),
110    };
111
112    handler.handle(request, auth, metadata).await
113}
114
115/// Request body wrapper for REST-style RPC calls.
116#[derive(Debug, serde::Deserialize)]
117pub struct RpcFunctionBody {
118    /// Function arguments.
119    #[serde(default)]
120    pub args: serde_json::Value,
121}
122
123/// Axum handler for POST /rpc/:function (REST-style).
124pub async fn rpc_function_handler(
125    State(handler): State<Arc<RpcHandler>>,
126    Extension(auth): Extension<AuthContext>,
127    Extension(tracing): Extension<TracingState>,
128    headers: HeaderMap,
129    axum::extract::Path(function): axum::extract::Path<String>,
130    Json(body): Json<RpcFunctionBody>,
131) -> RpcResponse {
132    let request = RpcRequest::new(function, body.args);
133
134    let metadata = RequestMetadata {
135        request_id: uuid::Uuid::parse_str(&tracing.request_id)
136            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
137        trace_id: tracing.trace_id,
138        client_ip: extract_client_ip(&headers),
139        user_agent: extract_user_agent(&headers),
140        timestamp: chrono::Utc::now(),
141    };
142
143    handler.handle(request, auth, metadata).await
144}
145
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Pool pointing at a nonexistent database.
    ///
    /// Relies on `connect_lazy` not opening a connection up front, so tests that
    /// never touch the database need no running server.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler over an empty function registry.
    fn create_test_handler() -> RpcHandler {
        RpcHandler::new(
            FunctionRegistry::new(),
            Database::from_pool(create_mock_pool()),
        )
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let response = create_test_handler()
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        assert!(response.error.is_some());
        let error = response.error.as_ref().unwrap();
        assert_eq!(error.code, "NOT_FOUND");
    }
}