use std::sync::Arc;
use std::time::Duration;
use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
use serde_json::Value;
use tokio::time::timeout;
use tracing::{Instrument, debug, error, info, trace, warn};
use super::registry::FunctionRegistry;
use super::router::{FunctionRouter, RouteResult};
use crate::db::Database;
use crate::signals::SignalsCollector;
/// Executes registered functions through a [`FunctionRouter`], applying a
/// per-function timeout and recording logs, metrics and (optionally) signals
/// for every call.
pub struct FunctionExecutor {
    /// Dispatches a function name plus arguments to the matching handler.
    router: FunctionRouter,
    /// Registry consulted for per-function metadata (timeout, log level, info).
    registry: Arc<FunctionRegistry>,
    /// Timeout used when a function's registry entry does not declare one.
    default_timeout: Duration,
    /// Destination for RPC-call signal events; `None` disables emission.
    signals_collector: Option<SignalsCollector>,
    /// Secret passed to visitor-id generation when signals are enabled.
    signals_server_secret: String,
}
impl FunctionExecutor {
    /// Fallback timeout for functions whose registry entry declares none.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

    /// Creates an executor with a 30-second fallback timeout and a router that
    /// has no job dispatcher, workflow dispatcher or token issuer attached.
    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
        Self::with_timeout(registry, db, Self::DEFAULT_TIMEOUT)
    }

    /// Creates an executor whose fallback timeout is `default_timeout`.
    ///
    /// Functions whose registry entry declares a timeout still use their own.
    pub fn with_timeout(
        registry: Arc<FunctionRegistry>,
        db: Database,
        default_timeout: Duration,
    ) -> Self {
        let router = FunctionRouter::new(Arc::clone(&registry), db);
        Self::from_parts(router, registry, default_timeout)
    }

    /// Creates an executor whose router can dispatch jobs and/or workflows.
    ///
    /// Equivalent to [`Self::with_dispatch_and_issuer`] without a token issuer.
    pub fn with_dispatch(
        registry: Arc<FunctionRegistry>,
        db: Database,
        job_dispatcher: Option<Arc<dyn JobDispatch>>,
        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
    ) -> Self {
        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
    }

    /// Creates an executor whose router is configured with whichever of the
    /// job dispatcher, workflow dispatcher and token issuer are `Some`.
    ///
    /// Uses the 30-second fallback timeout.
    pub fn with_dispatch_and_issuer(
        registry: Arc<FunctionRegistry>,
        db: Database,
        job_dispatcher: Option<Arc<dyn JobDispatch>>,
        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
    ) -> Self {
        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
        if let Some(jd) = job_dispatcher {
            router = router.with_job_dispatcher(jd);
        }
        if let Some(wd) = workflow_dispatcher {
            router = router.with_workflow_dispatcher(wd);
        }
        if let Some(issuer) = token_issuer {
            router = router.with_token_issuer(issuer);
        }
        Self::from_parts(router, registry, Self::DEFAULT_TIMEOUT)
    }

    /// Assembles an executor from a configured router, with signals disabled.
    fn from_parts(
        router: FunctionRouter,
        registry: Arc<FunctionRegistry>,
        default_timeout: Duration,
    ) -> Self {
        Self {
            router,
            registry,
            default_timeout,
            signals_collector: None,
            signals_server_secret: String::new(),
        }
    }

    /// Enables signal emission: each execution sends an RPC-call event to
    /// `collector`, using `server_secret` when deriving visitor ids.
    pub fn set_signals_collector(&mut self, collector: SignalsCollector, server_secret: String) {
        self.signals_collector = Some(collector);
        self.signals_server_secret = server_secret;
    }

    /// Forwards the auth token TTL setting to the underlying router.
    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
        self.router.set_token_ttl(ttl);
    }

    /// Executes `function_name` with `args` under its configured timeout.
    ///
    /// Every outcome (success, handler error, timeout) is logged, recorded as a
    /// metric and, if a signals collector is set, emitted as a signal event.
    ///
    /// # Errors
    ///
    /// - Returns `ForgeError::Timeout` if the function does not finish within
    ///   its timeout.
    /// - Otherwise propagates any error returned by the router.
    pub async fn execute(
        &self,
        function_name: &str,
        args: Value,
        auth: AuthContext,
        request: RequestMetadata,
    ) -> Result<ExecutionResult> {
        let start = std::time::Instant::now();
        let fn_timeout = self.get_function_timeout(function_name);
        let log_level = self.get_function_log_level(function_name);
        let kind = self
            .router
            .get_function_kind(function_name)
            .map(|k| k.to_string())
            .unwrap_or_else(|| "unknown".to_string());
        // Captured up front because `auth` and `request` are moved into the router.
        let signal_ctx = self.signals_collector.as_ref().map(|_| SignalContext {
            user_id: auth.user_id(),
            tenant_id: auth.tenant_id(),
            correlation_id: request.correlation_id.clone(),
            client_ip: request.client_ip.clone(),
            user_agent: request.user_agent.clone(),
        });
        let span = tracing::info_span!(
            "fn.execute",
            function = function_name,
            fn.kind = %kind,
        );

        // `args` is cloned because the original value is still needed for logging.
        let routed = timeout(
            fn_timeout,
            self.router
                .route(function_name, args.clone(), auth, request)
                .instrument(span),
        )
        .await;
        let duration = start.elapsed();

        let route_result = match routed {
            Err(_elapsed) => {
                let message = format!("Timeout after {:?}", fn_timeout);
                // Use the resolved kind (not "unknown") so the log line agrees with
                // the metric and signal recorded for the same timeout.
                self.record_outcome(
                    log_level,
                    function_name,
                    &kind,
                    &args,
                    duration,
                    Some(message.as_str()),
                    &signal_ctx,
                );
                return Err(ForgeError::Timeout(format!(
                    "Function '{}' timed out after {:?}",
                    function_name, fn_timeout
                )));
            }
            Ok(Err(e)) => {
                let message = e.to_string();
                self.record_outcome(
                    log_level,
                    function_name,
                    &kind,
                    &args,
                    duration,
                    Some(message.as_str()),
                    &signal_ctx,
                );
                return Err(e);
            }
            Ok(Ok(route_result)) => route_result,
        };

        let (result_kind, value) = match route_result {
            RouteResult::Query(v) => ("query", v),
            RouteResult::Mutation(v) => ("mutation", v),
            RouteResult::Job(v) => ("job", v),
            RouteResult::Workflow(v) => ("workflow", v),
        };
        self.record_outcome(
            log_level,
            function_name,
            result_kind,
            &args,
            duration,
            None,
            &signal_ctx,
        );
        Ok(ExecutionResult {
            function_name: function_name.to_string(),
            function_kind: result_kind.to_string(),
            result: value,
            duration,
            success: true,
            error: None,
        })
    }

    /// Logs, records the execution metric and emits a signal for one outcome.
    ///
    /// `error` is `None` for a successful call and carries the failure message
    /// otherwise; success is derived from it.
    // Arguments mirror `log_execution` plus the signal context; bundling them
    // into a struct would only add boilerplate for a private helper.
    #[allow(clippy::too_many_arguments)]
    fn record_outcome(
        &self,
        log_level: &str,
        function_name: &str,
        kind: &str,
        args: &Value,
        duration: Duration,
        error: Option<&str>,
        signal_ctx: &Option<SignalContext>,
    ) {
        let success = error.is_none();
        self.log_execution(log_level, function_name, kind, args, duration, success, error);
        crate::observability::record_fn_execution(
            function_name,
            kind,
            success,
            duration.as_secs_f64(),
        );
        self.emit_signal(function_name, kind, duration, success, signal_ctx);
    }

    /// Sends an RPC-call signal event for one execution.
    ///
    /// Does nothing when no signals collector is configured or `ctx` is `None`.
    fn emit_signal(
        &self,
        function_name: &str,
        function_kind: &str,
        duration: Duration,
        success: bool,
        ctx: &Option<SignalContext>,
    ) {
        let Some(collector) = &self.signals_collector else {
            return;
        };
        let Some(ctx) = ctx else { return };
        let is_bot = crate::signals::bot::is_bot(ctx.user_agent.as_deref());
        // A visitor id is only derived when the client IP is known.
        let visitor_id = ctx.client_ip.as_ref().map(|_| {
            crate::signals::visitor::generate_visitor_id(
                ctx.client_ip.as_deref(),
                ctx.user_agent.as_deref(),
                &self.signals_server_secret,
            )
        });
        // Saturate instead of a bare `as i32`, which would wrap the u128
        // millisecond count into a bogus (possibly negative) value.
        let duration_ms = duration.as_millis().min(i32::MAX as u128) as i32;
        let event = forge_core::signals::SignalEvent::rpc_call(
            function_name,
            function_kind,
            duration_ms,
            success,
            ctx.user_id,
            ctx.tenant_id,
            ctx.correlation_id.clone(),
            ctx.client_ip.clone(),
            ctx.user_agent.clone(),
            visitor_id,
            is_bot,
        );
        collector.try_send(event);
    }

    /// Writes the tracing log lines for one execution.
    ///
    /// - Failures are always logged at `error` level, regardless of `log_level`.
    /// - Successes are logged at `log_level`: `"off"` suppresses them and any
    ///   unrecognised value falls back to `trace`.
    /// - In both cases the input is logged separately at `debug` level.
    // The arguments are the distinct fields of the log event; a wrapper struct
    // would add no clarity for this private helper.
    #[allow(clippy::too_many_arguments)]
    fn log_execution(
        &self,
        log_level: &str,
        function_name: &str,
        kind: &str,
        input: &Value,
        duration: Duration,
        success: bool,
        error: Option<&str>,
    ) {
        if !success {
            error!(
                function = function_name,
                kind = kind,
                duration_ms = duration.as_millis() as u64,
                error = error,
                "Function failed"
            );
            debug!(
                function = function_name,
                input = %input,
                "Function input"
            );
            return;
        }
        // tracing's level macros cannot take a runtime level, so expand the same
        // event once per level.
        macro_rules! log_fn {
            ($level:ident) => {{
                $level!(
                    function = function_name,
                    kind = kind,
                    duration_ms = duration.as_millis() as u64,
                    "Function executed"
                );
                debug!(
                    function = function_name,
                    input = %input,
                    "Function input"
                );
            }};
        }
        match log_level {
            "off" => {}
            "error" => log_fn!(error),
            "warn" => log_fn!(warn),
            "info" => log_fn!(info),
            "debug" => log_fn!(debug),
            _ => log_fn!(trace),
        }
    }

    /// Returns the success log level for `function_name`.
    ///
    /// The level is resolved in this order:
    /// - the level declared in the function's registry info, if any;
    /// - otherwise `"info"` for mutations and `"debug"` for queries;
    /// - `"info"` for unregistered names.
    fn get_function_log_level(&self, function_name: &str) -> &'static str {
        self.registry
            .get(function_name)
            .map(|entry| {
                entry.info().log_level.unwrap_or(match entry.kind() {
                    forge_core::FunctionKind::Mutation => "info",
                    forge_core::FunctionKind::Query => "debug",
                })
            })
            .unwrap_or("info")
    }

    /// Returns the timeout for `function_name`.
    ///
    /// The registry's declared timeout is interpreted as whole seconds. When none
    /// is declared, or the name is unregistered, the executor's default is used.
    fn get_function_timeout(&self, function_name: &str) -> Duration {
        self.registry
            .get(function_name)
            .and_then(|entry| entry.info().timeout)
            .map(Duration::from_secs)
            .unwrap_or(self.default_timeout)
    }

    /// Returns a copy of the registry metadata for `function_name`, if registered.
    pub fn function_info(&self, function_name: &str) -> Option<forge_core::FunctionInfo> {
        self.registry.get(function_name).map(|e| e.info().clone())
    }

    /// Reports whether the router knows a function named `function_name`.
    pub fn has_function(&self, function_name: &str) -> bool {
        self.router.has_function(function_name)
    }
}
/// Request/auth details snapshotted before routing, so a signal event can
/// still be built after `AuthContext` and `RequestMetadata` were moved away.
struct SignalContext {
    /// Authenticated user, if any.
    user_id: Option<uuid::Uuid>,
    /// Tenant of the authenticated caller, if any.
    tenant_id: Option<uuid::Uuid>,
    /// Correlation id copied from the request metadata.
    correlation_id: Option<String>,
    /// Client IP copied from the request metadata.
    client_ip: Option<String>,
    /// User-Agent copied from the request metadata.
    user_agent: Option<String>,
}
/// Outcome of executing a registered function, as returned by
/// `FunctionExecutor::execute`.
///
/// Serialized field order follows declaration order; `duration` is encoded as
/// an integer number of milliseconds.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize
)]
pub struct ExecutionResult {
    /// Name of the function that was executed.
    pub function_name: String,
    /// Kind reported by the router (e.g. `"query"`, `"mutation"`).
    pub function_kind: String,
    /// Value returned by the function.
    pub result: Value,
    /// Wall-clock time spent executing, serialized as whole milliseconds.
    #[serde(with = "duration_millis")]
    pub duration: Duration,
    /// Whether the execution succeeded.
    pub success: bool,
    /// Failure message, if any (`execute` reports failures as `Err`, so the
    /// values it builds always carry `None`).
    pub error: Option<String>,
}
/// Serde adapter (for `#[serde(with = "duration_millis")]`) that encodes a
/// [`Duration`] as an unsigned integer number of milliseconds.
mod duration_millis {
    use serde::{Deserialize, Deserializer, Serializer};
    use std::time::Duration;

    /// Serializes `duration` as its whole-millisecond count.
    ///
    /// Sub-millisecond precision is truncated. Counts that do not fit in a
    /// `u64` saturate to `u64::MAX` rather than wrapping as an `as` cast would.
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let millis = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        serializer.serialize_u64(millis)
    }

    /// Deserializes a `u64` millisecond count back into a [`Duration`].
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        let millis = u64::deserialize(deserializer)?;
        Ok(Duration::from_millis(millis))
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"duration\":42"));
        assert!(json.contains("\"success\":true"));
    }

    #[test]
    fn test_execution_result_round_trip() {
        let original = ExecutionResult {
            function_name: "create_user".to_string(),
            function_kind: "mutation".to_string(),
            result: serde_json::json!({"id": "456"}),
            duration: Duration::from_millis(100),
            success: true,
            error: None,
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.function_name, "create_user");
        assert_eq!(deserialized.function_kind, "mutation");
        assert_eq!(deserialized.result, serde_json::json!({"id": "456"}));
        assert_eq!(deserialized.duration, Duration::from_millis(100));
        assert!(deserialized.success);
        assert!(deserialized.error.is_none());
    }

    #[test]
    fn test_execution_result_failure_round_trip() {
        let original = ExecutionResult {
            function_name: "delete_user".to_string(),
            function_kind: "mutation".to_string(),
            result: serde_json::Value::Null,
            duration: Duration::from_millis(7),
            success: false,
            error: Some("not found".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();
        assert!(!deserialized.success);
        assert_eq!(deserialized.error.as_deref(), Some("not found"));
        assert_eq!(deserialized.result, serde_json::Value::Null);
        assert_eq!(deserialized.duration, Duration::from_millis(7));
    }

    #[test]
    fn test_duration_truncates_to_whole_millis() {
        let result = ExecutionResult {
            function_name: "ping".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::Value::Null,
            duration: Duration::from_micros(1_999),
            success: true,
            error: None,
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"duration\":1"));
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.duration, Duration::from_millis(1));
    }
}