forge_runtime/gateway/
rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6    http::{HeaderMap, header::USER_AGENT},
7};
8use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
9
10use super::request::RpcRequest;
11use super::response::{RpcError, RpcResponse};
12use super::tracing::TracingState;
13use crate::function::{FunctionExecutor, FunctionRegistry};
14
15/// RPC handler for function invocations.
16#[derive(Clone)]
17pub struct RpcHandler {
18    /// Function executor.
19    executor: Arc<FunctionExecutor>,
20}
21
22impl RpcHandler {
23    /// Create a new RPC handler.
24    pub fn new(registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
25        let executor = FunctionExecutor::new(Arc::new(registry), db_pool);
26        Self {
27            executor: Arc::new(executor),
28        }
29    }
30
31    /// Create a new RPC handler with dispatch capabilities.
32    pub fn with_dispatch(
33        registry: FunctionRegistry,
34        db_pool: sqlx::PgPool,
35        job_dispatcher: Option<Arc<dyn JobDispatch>>,
36        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
37    ) -> Self {
38        let executor = FunctionExecutor::with_dispatch(
39            Arc::new(registry),
40            db_pool,
41            job_dispatcher,
42            workflow_dispatcher,
43        );
44        Self {
45            executor: Arc::new(executor),
46        }
47    }
48
49    /// Handle an RPC request.
50    pub async fn handle(
51        &self,
52        request: RpcRequest,
53        auth: AuthContext,
54        metadata: RequestMetadata,
55    ) -> RpcResponse {
56        if !self.executor.has_function(&request.function) {
57            return RpcResponse::error(RpcError::not_found(format!(
58                "Function '{}' not found",
59                request.function
60            )))
61            .with_request_id(metadata.request_id.to_string());
62        }
63
64        match self
65            .executor
66            .execute(&request.function, request.args, auth, metadata.clone())
67            .await
68        {
69            Ok(exec_result) => {
70                if exec_result.success {
71                    RpcResponse::success(exec_result.result)
72                        .with_request_id(metadata.request_id.to_string())
73                } else {
74                    RpcResponse::error(RpcError::internal(
75                        exec_result
76                            .error
77                            .unwrap_or_else(|| "Unknown error".to_string()),
78                    ))
79                    .with_request_id(metadata.request_id.to_string())
80                }
81            }
82            Err(e) => RpcResponse::error(RpcError::from(e))
83                .with_request_id(metadata.request_id.to_string()),
84        }
85    }
86}
87
88/// Extract client IP from X-Forwarded-For or X-Real-IP headers.
89fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
90    headers
91        .get("x-forwarded-for")
92        .and_then(|v| v.to_str().ok())
93        .map(|s| s.split(',').next().unwrap_or("").trim().to_string())
94        .filter(|s| !s.is_empty())
95        .or_else(|| {
96            headers
97                .get("x-real-ip")
98                .and_then(|v| v.to_str().ok())
99                .map(|s| s.trim().to_string())
100                .filter(|s| !s.is_empty())
101        })
102}
103
104/// Extract user agent from headers.
105fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
106    headers
107        .get(USER_AGENT)
108        .and_then(|v| v.to_str().ok())
109        .map(String::from)
110}
111
112/// Axum handler for POST /rpc.
113pub async fn rpc_handler(
114    State(handler): State<Arc<RpcHandler>>,
115    Extension(auth): Extension<AuthContext>,
116    Extension(tracing): Extension<TracingState>,
117    headers: HeaderMap,
118    Json(request): Json<RpcRequest>,
119) -> RpcResponse {
120    let metadata = RequestMetadata {
121        request_id: uuid::Uuid::parse_str(&tracing.request_id)
122            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
123        trace_id: tracing.trace_id,
124        client_ip: extract_client_ip(&headers),
125        user_agent: extract_user_agent(&headers),
126        timestamp: chrono::Utc::now(),
127    };
128
129    handler.handle(request, auth, metadata).await
130}
131
132/// Request body wrapper for REST-style RPC calls.
133#[derive(Debug, serde::Deserialize)]
134pub struct RpcFunctionBody {
135    /// Function arguments.
136    #[serde(default)]
137    pub args: serde_json::Value,
138}
139
140/// Axum handler for POST /rpc/:function (REST-style).
141pub async fn rpc_function_handler(
142    State(handler): State<Arc<RpcHandler>>,
143    Extension(auth): Extension<AuthContext>,
144    Extension(tracing): Extension<TracingState>,
145    headers: HeaderMap,
146    axum::extract::Path(function): axum::extract::Path<String>,
147    Json(body): Json<RpcFunctionBody>,
148) -> RpcResponse {
149    let request = RpcRequest::new(function, body.args);
150
151    let metadata = RequestMetadata {
152        request_id: uuid::Uuid::parse_str(&tracing.request_id)
153            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
154        trace_id: tracing.trace_id,
155        client_ip: extract_client_ip(&headers),
156        user_agent: extract_user_agent(&headers),
157        timestamp: chrono::Utc::now(),
158    };
159
160    handler.handle(request, auth, metadata).await
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single-connection pool with `connect_lazy`, pointing at a URL
    /// the tests never actually query.
    fn lazy_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler backed by a freshly created registry and the lazy pool.
    fn handler_with_empty_registry() -> RpcHandler {
        RpcHandler::new(FunctionRegistry::new(), lazy_pool())
    }

    /// Calling a function that was never registered must produce a
    /// `NOT_FOUND` error response.
    #[tokio::test]
    async fn test_handle_unknown_function() {
        let rpc = handler_with_empty_registry();

        let response = rpc
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);

        let error = response.error.as_ref();
        assert!(error.is_some());
        assert_eq!(error.unwrap().code, "NOT_FOUND");
    }
}