Skip to main content

forge_runtime/function/
executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12use crate::signals::SignalsCollector;
13
14/// Executes functions with timeout and error handling.
15pub struct FunctionExecutor {
16    router: FunctionRouter,
17    registry: Arc<FunctionRegistry>,
18    default_timeout: Duration,
19    signals_collector: Option<SignalsCollector>,
20    signals_server_secret: String,
21}
22
23impl FunctionExecutor {
24    /// Create a new function executor.
25    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
26        Self {
27            router: FunctionRouter::new(Arc::clone(&registry), db),
28            registry,
29            default_timeout: Duration::from_secs(30),
30            signals_collector: None,
31            signals_server_secret: String::new(),
32        }
33    }
34
35    /// Create a new function executor with custom timeout.
36    pub fn with_timeout(
37        registry: Arc<FunctionRegistry>,
38        db: Database,
39        default_timeout: Duration,
40    ) -> Self {
41        Self {
42            router: FunctionRouter::new(Arc::clone(&registry), db),
43            registry,
44            default_timeout,
45            signals_collector: None,
46            signals_server_secret: String::new(),
47        }
48    }
49
50    /// Create a new function executor with dispatch capabilities.
51    pub fn with_dispatch(
52        registry: Arc<FunctionRegistry>,
53        db: Database,
54        job_dispatcher: Option<Arc<dyn JobDispatch>>,
55        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
56    ) -> Self {
57        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
58    }
59
60    /// Create a function executor with dispatch and token issuer.
61    pub fn with_dispatch_and_issuer(
62        registry: Arc<FunctionRegistry>,
63        db: Database,
64        job_dispatcher: Option<Arc<dyn JobDispatch>>,
65        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
66        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
67    ) -> Self {
68        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
69        if let Some(jd) = job_dispatcher {
70            router = router.with_job_dispatcher(jd);
71        }
72        if let Some(wd) = workflow_dispatcher {
73            router = router.with_workflow_dispatcher(wd);
74        }
75        if let Some(issuer) = token_issuer {
76            router = router.with_token_issuer(issuer);
77        }
78        Self {
79            router,
80            registry,
81            default_timeout: Duration::from_secs(30),
82            signals_collector: None,
83            signals_server_secret: String::new(),
84        }
85    }
86
87    /// Set the signals collector for auto-capturing RPC events.
88    pub fn set_signals_collector(&mut self, collector: SignalsCollector, server_secret: String) {
89        self.signals_collector = Some(collector);
90        self.signals_server_secret = server_secret;
91    }
92
93    /// Set the token TTL config on the underlying router.
94    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
95        self.router.set_token_ttl(ttl);
96    }
97
98    /// Execute a function call.
99    pub async fn execute(
100        &self,
101        function_name: &str,
102        args: Value,
103        auth: AuthContext,
104        request: RequestMetadata,
105    ) -> Result<ExecutionResult> {
106        let start = std::time::Instant::now();
107        let fn_timeout = self.get_function_timeout(function_name);
108        let log_level = self.get_function_log_level(function_name);
109
110        let kind = self
111            .router
112            .get_function_kind(function_name)
113            .map(|k| k.to_string())
114            .unwrap_or_else(|| "unknown".to_string());
115
116        // Capture signal metadata before auth/request are consumed
117        let signal_ctx = self.signals_collector.as_ref().map(|_| SignalContext {
118            user_id: auth.user_id(),
119            tenant_id: auth.tenant_id(),
120            correlation_id: request.correlation_id.clone(),
121            client_ip: request.client_ip.clone(),
122            user_agent: request.user_agent.clone(),
123        });
124
125        let span = tracing::info_span!(
126            "fn.execute",
127            function = function_name,
128            fn.kind = %kind,
129        );
130
131        let result = match timeout(
132            fn_timeout,
133            self.router
134                .route(function_name, args.clone(), auth, request)
135                .instrument(span),
136        )
137        .await
138        {
139            Ok(result) => result,
140            Err(_) => {
141                let duration = start.elapsed();
142                self.log_execution(
143                    log_level,
144                    function_name,
145                    "unknown",
146                    &args,
147                    duration,
148                    false,
149                    Some(&format!("Timeout after {:?}", fn_timeout)),
150                );
151                crate::observability::record_fn_execution(
152                    function_name,
153                    &kind,
154                    false,
155                    duration.as_secs_f64(),
156                );
157                self.emit_signal(function_name, &kind, duration, false, &signal_ctx);
158                return Err(ForgeError::Timeout(format!(
159                    "Function '{}' timed out after {:?}",
160                    function_name, fn_timeout
161                )));
162            }
163        };
164
165        let duration = start.elapsed();
166
167        match result {
168            Ok(route_result) => {
169                let (result_kind, value) = match route_result {
170                    RouteResult::Query(v) => ("query", v),
171                    RouteResult::Mutation(v) => ("mutation", v),
172                    RouteResult::Job(v) => ("job", v),
173                    RouteResult::Workflow(v) => ("workflow", v),
174                };
175
176                self.log_execution(
177                    log_level,
178                    function_name,
179                    result_kind,
180                    &args,
181                    duration,
182                    true,
183                    None,
184                );
185                crate::observability::record_fn_execution(
186                    function_name,
187                    result_kind,
188                    true,
189                    duration.as_secs_f64(),
190                );
191                self.emit_signal(function_name, result_kind, duration, true, &signal_ctx);
192
193                Ok(ExecutionResult {
194                    function_name: function_name.to_string(),
195                    function_kind: result_kind.to_string(),
196                    result: value,
197                    duration,
198                    success: true,
199                    error: None,
200                })
201            }
202            Err(e) => {
203                self.log_execution(
204                    log_level,
205                    function_name,
206                    &kind,
207                    &args,
208                    duration,
209                    false,
210                    Some(&e.to_string()),
211                );
212                crate::observability::record_fn_execution(
213                    function_name,
214                    &kind,
215                    false,
216                    duration.as_secs_f64(),
217                );
218                self.emit_signal(function_name, &kind, duration, false, &signal_ctx);
219
220                Err(e)
221            }
222        }
223    }
224
225    /// Emit a signal event for RPC auto-capture.
226    fn emit_signal(
227        &self,
228        function_name: &str,
229        function_kind: &str,
230        duration: Duration,
231        success: bool,
232        ctx: &Option<SignalContext>,
233    ) {
234        let Some(collector) = &self.signals_collector else {
235            return;
236        };
237        let Some(ctx) = ctx else { return };
238
239        let is_bot = crate::signals::bot::is_bot(ctx.user_agent.as_deref());
240        let visitor_id = ctx.client_ip.as_ref().map(|_| {
241            crate::signals::visitor::generate_visitor_id(
242                ctx.client_ip.as_deref(),
243                ctx.user_agent.as_deref(),
244                &self.signals_server_secret,
245            )
246        });
247
248        let event = forge_core::signals::SignalEvent::rpc_call(
249            function_name,
250            function_kind,
251            duration.as_millis() as i32,
252            success,
253            ctx.user_id,
254            ctx.tenant_id,
255            ctx.correlation_id.clone(),
256            ctx.client_ip.clone(),
257            ctx.user_agent.clone(),
258            visitor_id,
259            is_bot,
260        );
261        collector.try_send(event);
262    }
263
264    /// Log function execution at the configured level.
265    #[allow(clippy::too_many_arguments)]
266    fn log_execution(
267        &self,
268        log_level: &str,
269        function_name: &str,
270        kind: &str,
271        input: &Value,
272        duration: Duration,
273        success: bool,
274        error: Option<&str>,
275    ) {
276        // Failures are always logged at error regardless of the function's
277        // configured log level. Successes use the configured level.
278        if !success {
279            error!(
280                function = function_name,
281                kind = kind,
282                duration_ms = duration.as_millis() as u64,
283                error = error,
284                "Function failed"
285            );
286            debug!(
287                function = function_name,
288                input = %input,
289                "Function input"
290            );
291            return;
292        }
293
294        macro_rules! log_fn {
295            ($level:ident) => {{
296                $level!(
297                    function = function_name,
298                    kind = kind,
299                    duration_ms = duration.as_millis() as u64,
300                    "Function executed"
301                );
302                debug!(
303                    function = function_name,
304                    input = %input,
305                    "Function input"
306                );
307            }};
308        }
309
310        match log_level {
311            "off" => {}
312            "error" => log_fn!(error),
313            "warn" => log_fn!(warn),
314            "info" => log_fn!(info),
315            "debug" => log_fn!(debug),
316            _ => log_fn!(trace),
317        }
318    }
319
320    /// Mutations default to "info" because writes are worth tracking.
321    /// Queries default to "debug" since they're high-volume.
322    fn get_function_log_level(&self, function_name: &str) -> &'static str {
323        self.registry
324            .get(function_name)
325            .map(|entry| {
326                entry.info().log_level.unwrap_or(match entry.kind() {
327                    forge_core::FunctionKind::Mutation => "info",
328                    forge_core::FunctionKind::Query => "debug",
329                })
330            })
331            .unwrap_or("info")
332    }
333
334    /// Get the timeout for a specific function.
335    fn get_function_timeout(&self, function_name: &str) -> Duration {
336        self.registry
337            .get(function_name)
338            .and_then(|entry| entry.info().timeout)
339            .map(Duration::from_secs)
340            .unwrap_or(self.default_timeout)
341    }
342
343    /// Look up function metadata by name.
344    pub fn function_info(&self, function_name: &str) -> Option<forge_core::FunctionInfo> {
345        self.registry.get(function_name).map(|e| e.info().clone())
346    }
347
348    /// Check if a function exists.
349    pub fn has_function(&self, function_name: &str) -> bool {
350        self.router.has_function(function_name)
351    }
352}
353
354/// Captured metadata from auth/request for signal emission.
355struct SignalContext {
356    user_id: Option<uuid::Uuid>,
357    tenant_id: Option<uuid::Uuid>,
358    correlation_id: Option<String>,
359    client_ip: Option<String>,
360    user_agent: Option<String>,
361}
362
363/// Result of executing a function.
364#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
365pub struct ExecutionResult {
366    /// Function name that was executed.
367    pub function_name: String,
368    /// Kind of function (query, mutation).
369    pub function_kind: String,
370    /// The result value (or null on error).
371    pub result: Value,
372    /// Execution duration.
373    #[serde(with = "duration_millis")]
374    pub duration: Duration,
375    /// Whether execution succeeded.
376    pub success: bool,
377    /// Error message if failed.
378    pub error: Option<String>,
379}
380
381mod duration_millis {
382    use serde::{Deserialize, Deserializer, Serializer};
383    use std::time::Duration;
384
385    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
386    where
387        S: Serializer,
388    {
389        serializer.serialize_u64(duration.as_millis() as u64)
390    }
391
392    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
393    where
394        D: Deserializer<'de>,
395    {
396        let millis = u64::deserialize(deserializer)?;
397        Ok(Duration::from_millis(millis))
398    }
399}
400
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        // Inspect the structured form instead of substring-matching JSON text.
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["function_name"], "get_user");
        assert_eq!(json["function_kind"], "query");
        assert_eq!(json["result"], serde_json::json!({"id": "123"}));
        assert_eq!(json["duration"], 42_u64);
        assert_eq!(json["success"], true);
        assert!(json["error"].is_null());
    }

    #[test]
    fn test_duration_serializes_as_whole_millis() {
        let result = ExecutionResult {
            function_name: "ping".to_string(),
            function_kind: "query".to_string(),
            result: Value::Null,
            duration: Duration::from_micros(1_500),
            success: true,
            error: None,
        };

        // Sub-millisecond precision is truncated, not rounded.
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["duration"], 1_u64);
    }

    #[test]
    fn test_execution_result_round_trip() {
        let original = ExecutionResult {
            function_name: "create_user".to_string(),
            function_kind: "mutation".to_string(),
            result: serde_json::json!({"id": "456"}),
            duration: Duration::from_millis(100),
            success: true,
            error: None,
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.function_name, "create_user");
        assert_eq!(deserialized.function_kind, "mutation");
        assert_eq!(deserialized.result, original.result);
        assert_eq!(deserialized.duration, Duration::from_millis(100));
        assert!(deserialized.success);
        assert_eq!(deserialized.error, None);
    }

    #[test]
    fn test_execution_result_round_trip_preserves_error() {
        let original = ExecutionResult {
            function_name: "delete_user".to_string(),
            function_kind: "mutation".to_string(),
            result: Value::Null,
            duration: Duration::ZERO,
            success: false,
            error: Some("not found".to_string()),
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();

        assert!(!deserialized.success);
        assert!(deserialized.result.is_null());
        assert_eq!(deserialized.duration, Duration::ZERO);
        assert_eq!(deserialized.error.as_deref(), Some("not found"));
    }
}