forge_runtime/gateway/
rpc.rs

1use std::sync::Arc;
2
3use axum::{
4    Json,
5    extract::{Extension, State},
6};
7use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};
8
9use super::request::RpcRequest;
10use super::response::{RpcError, RpcResponse};
11use super::tracing::TracingState;
12use crate::function::{FunctionExecutor, FunctionRegistry};
13
14/// RPC handler for function invocations.
15#[derive(Clone)]
16pub struct RpcHandler {
17    /// Function executor.
18    executor: Arc<FunctionExecutor>,
19}
20
21impl RpcHandler {
22    /// Create a new RPC handler.
23    pub fn new(registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
24        let executor = FunctionExecutor::new(Arc::new(registry), db_pool);
25        Self {
26            executor: Arc::new(executor),
27        }
28    }
29
30    /// Create a new RPC handler with dispatch capabilities.
31    pub fn with_dispatch(
32        registry: FunctionRegistry,
33        db_pool: sqlx::PgPool,
34        job_dispatcher: Option<Arc<dyn JobDispatch>>,
35        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
36    ) -> Self {
37        let executor = FunctionExecutor::with_dispatch(
38            Arc::new(registry),
39            db_pool,
40            job_dispatcher,
41            workflow_dispatcher,
42        );
43        Self {
44            executor: Arc::new(executor),
45        }
46    }
47
48    /// Handle an RPC request.
49    pub async fn handle(
50        &self,
51        request: RpcRequest,
52        auth: AuthContext,
53        metadata: RequestMetadata,
54    ) -> RpcResponse {
55        // Check if function exists
56        if !self.executor.has_function(&request.function) {
57            return RpcResponse::error(RpcError::not_found(format!(
58                "Function '{}' not found",
59                request.function
60            )))
61            .with_request_id(metadata.request_id.to_string());
62        }
63
64        // Execute function
65        match self
66            .executor
67            .execute(&request.function, request.args, auth, metadata.clone())
68            .await
69        {
70            Ok(exec_result) => {
71                if exec_result.success {
72                    RpcResponse::success(exec_result.result)
73                        .with_request_id(metadata.request_id.to_string())
74                } else {
75                    RpcResponse::error(RpcError::internal(
76                        exec_result
77                            .error
78                            .unwrap_or_else(|| "Unknown error".to_string()),
79                    ))
80                    .with_request_id(metadata.request_id.to_string())
81                }
82            }
83            Err(e) => RpcResponse::error(RpcError::from(e))
84                .with_request_id(metadata.request_id.to_string()),
85        }
86    }
87}
88
89/// Axum handler for POST /rpc.
90pub async fn rpc_handler(
91    State(handler): State<Arc<RpcHandler>>,
92    Extension(auth): Extension<AuthContext>,
93    Extension(tracing): Extension<TracingState>,
94    Json(request): Json<RpcRequest>,
95) -> RpcResponse {
96    let metadata = RequestMetadata {
97        request_id: uuid::Uuid::parse_str(&tracing.request_id)
98            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
99        trace_id: tracing.trace_id,
100        client_ip: None,
101        user_agent: None,
102        timestamp: chrono::Utc::now(),
103    };
104
105    handler.handle(request, auth, metadata).await
106}
107
108/// Axum handler for POST /rpc/:function (REST-style).
109pub async fn rpc_function_handler(
110    State(handler): State<Arc<RpcHandler>>,
111    Extension(auth): Extension<AuthContext>,
112    Extension(tracing): Extension<TracingState>,
113    axum::extract::Path(function): axum::extract::Path<String>,
114    Json(args): Json<serde_json::Value>,
115) -> RpcResponse {
116    let request = RpcRequest::new(function, args);
117
118    let metadata = RequestMetadata {
119        request_id: uuid::Uuid::parse_str(&tracing.request_id)
120            .unwrap_or_else(|_| uuid::Uuid::new_v4()),
121        trace_id: tracing.trace_id,
122        client_ip: None,
123        user_agent: None,
124        timestamp: chrono::Utc::now(),
125    };
126
127    handler.handle(request, auth, metadata).await
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `RpcHandler` over an empty registry and a lazily-connecting pool.
    ///
    /// The pool is created with `connect_lazy`, so construction itself does not contact a
    /// database; tests that never reach the executor need no running Postgres.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool");
        RpcHandler::new(registry, pool)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let response = create_test_handler()
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let error = response
            .error
            .as_ref()
            .expect("a failed response should carry an error");
        assert_eq!(error.code, "NOT_FOUND");
    }
}