use std::sync::Arc;
use axum::{
Json,
extract::{Extension, State},
http::{HeaderMap, header::USER_AGENT},
};
use forge_core::function::{
AuthContext, FunctionInfo, JobDispatch, RequestMetadata, WorkflowDispatch,
};
use super::request::{BatchRpcRequest, BatchRpcResponse, RpcRequest};
use super::response::{RpcError, RpcResponse};
use super::tracing::TracingState;
use crate::db::Database;
use crate::function::{FunctionExecutor, FunctionRegistry};
#[derive(Clone)]
pub struct RpcHandler {
executor: Arc<FunctionExecutor>,
}
impl RpcHandler {
pub fn new(registry: FunctionRegistry, db: Database) -> Self {
let executor = FunctionExecutor::new(Arc::new(registry), db);
Self {
executor: Arc::new(executor),
}
}
pub fn with_dispatch(
registry: FunctionRegistry,
db: Database,
job_dispatcher: Option<Arc<dyn JobDispatch>>,
workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
) -> Self {
Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
}
pub fn with_dispatch_and_issuer(
registry: FunctionRegistry,
db: Database,
job_dispatcher: Option<Arc<dyn JobDispatch>>,
workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
) -> Self {
let executor = FunctionExecutor::with_dispatch_and_issuer(
Arc::new(registry),
db,
job_dispatcher,
workflow_dispatcher,
token_issuer,
);
Self {
executor: Arc::new(executor),
}
}
pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
if let Some(executor) = Arc::get_mut(&mut self.executor) {
executor.set_token_ttl(ttl);
}
}
pub fn function_info(&self, name: &str) -> Option<FunctionInfo> {
self.executor.function_info(name)
}
pub fn set_signals_collector(
&mut self,
collector: crate::signals::SignalsCollector,
server_secret: String,
) {
if let Some(executor) = Arc::get_mut(&mut self.executor) {
executor.set_signals_collector(collector, server_secret);
}
}
pub async fn handle(
&self,
request: RpcRequest,
auth: AuthContext,
metadata: RequestMetadata,
) -> RpcResponse {
match self
.executor
.execute(&request.function, request.args, auth, metadata.clone())
.await
{
Ok(exec_result) => RpcResponse::success(exec_result.result)
.with_request_id(metadata.request_id.to_string()),
Err(e) => RpcResponse::error(RpcError::from(e))
.with_request_id(metadata.request_id.to_string()),
}
}
}
use super::extract_client_ip;
/// Returns the `User-Agent` header as an owned string.
///
/// Yields `None` when the header is missing or when its value cannot be viewed as
/// a string (`HeaderValue::to_str` fails). Only the first value is consulted.
fn extract_user_agent(headers: &HeaderMap) -> Option<String> {
    let value = headers.get(USER_AGENT)?;
    value.to_str().ok().map(str::to_owned)
}
/// Assembles the `RequestMetadata` for a single RPC call.
///
/// The request id comes from `tracing.request_id` when that parses as a UUID;
/// otherwise a fresh random v4 UUID is used. Client IP, user agent and correlation
/// id are read from `headers`, and the timestamp is the current UTC time.
fn build_metadata(tracing: TracingState, headers: &HeaderMap) -> RequestMetadata {
    let request_id = match uuid::Uuid::parse_str(&tracing.request_id) {
        Ok(parsed) => parsed,
        Err(_) => uuid::Uuid::new_v4(),
    };

    RequestMetadata {
        request_id,
        trace_id: tracing.trace_id,
        client_ip: extract_client_ip(headers),
        user_agent: extract_user_agent(headers),
        correlation_id: extract_correlation_id(headers),
        timestamp: chrono::Utc::now(),
    }
}
/// Returns the `x-correlation-id` header value, if it is usable.
///
/// `None` is returned when the header is absent, is not representable as a string,
/// is empty, or is longer than 64 bytes. Over-long values are ignored, not truncated.
fn extract_correlation_id(headers: &HeaderMap) -> Option<String> {
    const MAX_LEN: usize = 64;

    let raw = headers.get("x-correlation-id")?.to_str().ok()?;
    if raw.is_empty() || raw.len() > MAX_LEN {
        return None;
    }
    Some(raw.to_owned())
}
/// Axum handler for a single RPC call whose function name is in the JSON body.
///
/// Function names failing `is_valid_function_name` are rejected with a validation
/// error before the executor is involved. Every response from this endpoint,
/// including that validation error, carries the request id from the request
/// metadata so clients can correlate failures.
pub async fn rpc_handler(
    State(handler): State<Arc<RpcHandler>>,
    Extension(auth): Extension<AuthContext>,
    Extension(tracing): Extension<TracingState>,
    headers: HeaderMap,
    Json(request): Json<RpcRequest>,
) -> RpcResponse {
    // Build metadata first so the validation error can carry the same request id
    // that `RpcHandler::handle` attaches to every executed call.
    let metadata = build_metadata(tracing, &headers);
    if !is_valid_function_name(&request.function) {
        return RpcResponse::error(RpcError::validation(
            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
        ))
        .with_request_id(metadata.request_id.to_string());
    }
    handler.handle(request, auth, metadata).await
}
/// JSON body for the path-addressed RPC endpoint.
///
/// The function name comes from the URL path, so the body only carries arguments.
/// A missing `args` field deserializes to `serde_json::Value::Null`.
#[derive(serde::Deserialize, Debug)]
pub struct RpcFunctionBody {
    /// Arguments passed through to the target function.
    #[serde(default)]
    pub args: serde_json::Value,
}
/// Reports whether `name` is acceptable as an RPC function name.
///
/// A valid name is 1 to 256 bytes long (UTF-8 length, not character count) and
/// consists only of alphanumeric characters (as defined by `char::is_alphanumeric`,
/// so non-ASCII letters and digits pass), `_`, `.`, `:` or `-`.
fn is_valid_function_name(name: &str) -> bool {
    let allowed = |c: char| c.is_alphanumeric() || matches!(c, '_' | '.' | ':' | '-');
    (1..=256).contains(&name.len()) && name.chars().all(allowed)
}
/// Axum handler for an RPC call addressed by URL path.
///
/// The function name comes from the path segment and the arguments from the JSON
/// body. Invalid function names are rejected with a validation error. Every
/// response, including that validation error, carries the request id from the
/// request metadata so clients can correlate failures.
pub async fn rpc_function_handler(
    State(handler): State<Arc<RpcHandler>>,
    Extension(auth): Extension<AuthContext>,
    Extension(tracing): Extension<TracingState>,
    headers: HeaderMap,
    axum::extract::Path(function): axum::extract::Path<String>,
    Json(body): Json<RpcFunctionBody>,
) -> RpcResponse {
    // Build metadata first so the validation error can carry the same request id
    // that `RpcHandler::handle` attaches to every executed call.
    let metadata = build_metadata(tracing, &headers);
    if !is_valid_function_name(&function) {
        return RpcResponse::error(RpcError::validation(
            "Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
        ))
        .with_request_id(metadata.request_id.to_string());
    }
    let request = RpcRequest::new(function, body.args);
    handler.handle(request, auth, metadata).await
}
const MAX_BATCH_SIZE: usize = 100;
pub async fn rpc_batch_handler(
State(handler): State<Arc<RpcHandler>>,
Extension(auth): Extension<AuthContext>,
Extension(tracing): Extension<TracingState>,
headers: HeaderMap,
Json(batch): Json<BatchRpcRequest>,
) -> BatchRpcResponse {
if batch.requests.len() > MAX_BATCH_SIZE {
return BatchRpcResponse {
results: vec![RpcResponse::error(RpcError::validation(format!(
"Batch size {} exceeds maximum of {}",
batch.requests.len(),
MAX_BATCH_SIZE
)))],
};
}
let client_ip = extract_client_ip(&headers);
let user_agent = extract_user_agent(&headers);
let correlation_id = extract_correlation_id(&headers);
let mut results = Vec::with_capacity(batch.requests.len());
for request in batch.requests {
if !is_valid_function_name(&request.function) {
results.push(RpcResponse::error(RpcError::validation(
"Invalid function name: must be 1-256 alphanumeric characters, underscores, dots, colons, or hyphens",
)));
continue;
}
let metadata = RequestMetadata {
request_id: uuid::Uuid::new_v4(),
trace_id: tracing.trace_id.clone(),
client_ip: client_ip.clone(),
user_agent: user_agent.clone(),
correlation_id: correlation_id.clone(),
timestamp: chrono::Utc::now(),
};
let response = handler.handle(request, auth.clone(), metadata).await;
results.push(response);
}
BatchRpcResponse { results }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a lazily-connecting pool pointed at a nonexistent database. Nothing
    /// connects until a query runs, so tests that never reach the DB can use it.
    fn create_mock_pool() -> sqlx::PgPool {
        sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect_lazy("postgres://localhost/nonexistent")
            .expect("Failed to create mock pool")
    }

    /// Builds a handler with an empty function registry.
    fn create_test_handler() -> RpcHandler {
        let registry = FunctionRegistry::new();
        let db = Database::from_pool(create_mock_pool());
        RpcHandler::new(registry, db)
    }

    #[tokio::test]
    async fn test_handle_unknown_function() {
        let handler = create_test_handler();
        let request = RpcRequest::new("unknown_function", serde_json::json!({}));
        let auth = AuthContext::unauthenticated();
        let metadata = RequestMetadata::new();
        let response = handler.handle(request, auth, metadata).await;
        assert!(!response.success);
        assert!(response.error.is_some());
        assert_eq!(response.error.as_ref().unwrap().code, "NOT_FOUND");
    }

    #[test]
    fn test_is_valid_function_name() {
        assert!(is_valid_function_name("get_user"));
        assert!(is_valid_function_name("users.list:v2-beta"));
        assert!(is_valid_function_name(&"a".repeat(256)));
        assert!(!is_valid_function_name(""));
        assert!(!is_valid_function_name(&"a".repeat(257)));
        assert!(!is_valid_function_name("get user"));
        assert!(!is_valid_function_name("../etc/passwd"));
    }

    #[test]
    fn test_extract_correlation_id() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_correlation_id(&headers), None);

        headers.insert("x-correlation-id", HeaderValue::from_static("corr-123"));
        assert_eq!(extract_correlation_id(&headers).as_deref(), Some("corr-123"));

        headers.insert("x-correlation-id", HeaderValue::from_static(""));
        assert_eq!(extract_correlation_id(&headers), None);

        headers.insert(
            "x-correlation-id",
            HeaderValue::from_str(&"c".repeat(64)).unwrap(),
        );
        assert_eq!(extract_correlation_id(&headers).map(|v| v.len()), Some(64));

        headers.insert(
            "x-correlation-id",
            HeaderValue::from_str(&"c".repeat(65)).unwrap(),
        );
        assert_eq!(extract_correlation_id(&headers), None);
    }

    #[test]
    fn test_extract_user_agent() {
        let mut headers = HeaderMap::new();
        assert_eq!(extract_user_agent(&headers), None);

        headers.insert(USER_AGENT, HeaderValue::from_static("forge-test/1.0"));
        assert_eq!(extract_user_agent(&headers).as_deref(), Some("forge-test/1.0"));
    }
}