// forge_runtime/gateway/rpc.rs

use std::sync::Arc;

use axum::{
    Json,
    extract::{Extension, Path, State},
};
use forge_core::function::{AuthContext, JobDispatch, RequestMetadata, WorkflowDispatch};

use super::request::RpcRequest;
use super::response::{RpcError, RpcResponse};
use super::tracing::TracingState;
use crate::function::{FunctionExecutor, FunctionRegistry};

14#[derive(Clone)]
16pub struct RpcHandler {
17 executor: Arc<FunctionExecutor>,
19}
20
21impl RpcHandler {
22 pub fn new(registry: FunctionRegistry, db_pool: sqlx::PgPool) -> Self {
24 let executor = FunctionExecutor::new(Arc::new(registry), db_pool);
25 Self {
26 executor: Arc::new(executor),
27 }
28 }
29
30 pub fn with_dispatch(
32 registry: FunctionRegistry,
33 db_pool: sqlx::PgPool,
34 job_dispatcher: Option<Arc<dyn JobDispatch>>,
35 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
36 ) -> Self {
37 let executor = FunctionExecutor::with_dispatch(
38 Arc::new(registry),
39 db_pool,
40 job_dispatcher,
41 workflow_dispatcher,
42 );
43 Self {
44 executor: Arc::new(executor),
45 }
46 }
47
48 pub async fn handle(
50 &self,
51 request: RpcRequest,
52 auth: AuthContext,
53 metadata: RequestMetadata,
54 ) -> RpcResponse {
55 if !self.executor.has_function(&request.function) {
57 return RpcResponse::error(RpcError::not_found(format!(
58 "Function '{}' not found",
59 request.function
60 )))
61 .with_request_id(metadata.request_id.to_string());
62 }
63
64 match self
66 .executor
67 .execute(&request.function, request.args, auth, metadata.clone())
68 .await
69 {
70 Ok(exec_result) => {
71 if exec_result.success {
72 RpcResponse::success(exec_result.result)
73 .with_request_id(metadata.request_id.to_string())
74 } else {
75 RpcResponse::error(RpcError::internal(
76 exec_result
77 .error
78 .unwrap_or_else(|| "Unknown error".to_string()),
79 ))
80 .with_request_id(metadata.request_id.to_string())
81 }
82 }
83 Err(e) => RpcResponse::error(RpcError::from(e))
84 .with_request_id(metadata.request_id.to_string()),
85 }
86 }
87}
88
89pub async fn rpc_handler(
91 State(handler): State<Arc<RpcHandler>>,
92 Extension(auth): Extension<AuthContext>,
93 Extension(tracing): Extension<TracingState>,
94 Json(request): Json<RpcRequest>,
95) -> RpcResponse {
96 let metadata = RequestMetadata {
97 request_id: uuid::Uuid::parse_str(&tracing.request_id)
98 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
99 trace_id: tracing.trace_id,
100 client_ip: None,
101 user_agent: None,
102 timestamp: chrono::Utc::now(),
103 };
104
105 handler.handle(request, auth, metadata).await
106}
107
108pub async fn rpc_function_handler(
110 State(handler): State<Arc<RpcHandler>>,
111 Extension(auth): Extension<AuthContext>,
112 Extension(tracing): Extension<TracingState>,
113 axum::extract::Path(function): axum::extract::Path<String>,
114 Json(args): Json<serde_json::Value>,
115) -> RpcResponse {
116 let request = RpcRequest::new(function, args);
117
118 let metadata = RequestMetadata {
119 request_id: uuid::Uuid::parse_str(&tracing.request_id)
120 .unwrap_or_else(|_| uuid::Uuid::new_v4()),
121 trace_id: tracing.trace_id,
122 client_ip: None,
123 user_agent: None,
124 timestamp: chrono::Utc::now(),
125 };
126
127 handler.handle(request, auth, metadata).await
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-connection pool pointed at a database that does not exist.
    ///
    /// `handle` returns before reaching the executor for unregistered
    /// functions, so the test below never needs a live Postgres.
    fn create_mock_pool() -> sqlx::PgPool {
        let options = sqlx::postgres::PgPoolOptions::new().max_connections(1);
        options
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Handler backed by an empty registry and the mock pool.
    fn create_test_handler() -> RpcHandler {
        RpcHandler::new(FunctionRegistry::new(), create_mock_pool())
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();

        let response = handler
            .handle(
                RpcRequest::new("unknown_function", serde_json::json!({})),
                AuthContext::unauthenticated(),
                RequestMetadata::new(),
            )
            .await;

        assert!(!response.success);
        let err = response
            .error
            .as_ref()
            .expect("an unregistered function must produce an error");
        assert_eq!(err.code, "NOT_FOUND");
    }
}