Skip to main content

forge_runtime/function/
executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
/// Executes functions with timeout and error handling.
///
/// Each call is dispatched through a [`FunctionRouter`]; the registry is also
/// consulted directly for per-function timeout and log-level settings.
pub struct FunctionExecutor {
    /// Performs the actual dispatch (`route`) of each call.
    router: FunctionRouter,
    /// Same registry the router was built from; read here for per-function
    /// timeout and log-level overrides.
    registry: Arc<FunctionRegistry>,
    /// Timeout applied when a function does not declare its own.
    default_timeout: Duration,
}
19
20impl FunctionExecutor {
21    /// Create a new function executor.
22    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23        Self {
24            router: FunctionRouter::new(Arc::clone(&registry), db),
25            registry,
26            default_timeout: Duration::from_secs(30),
27        }
28    }
29
30    /// Create a new function executor with custom timeout.
31    pub fn with_timeout(
32        registry: Arc<FunctionRegistry>,
33        db: Database,
34        default_timeout: Duration,
35    ) -> Self {
36        Self {
37            router: FunctionRouter::new(Arc::clone(&registry), db),
38            registry,
39            default_timeout,
40        }
41    }
42
43    /// Create a new function executor with dispatch capabilities.
44    pub fn with_dispatch(
45        registry: Arc<FunctionRegistry>,
46        db: Database,
47        job_dispatcher: Option<Arc<dyn JobDispatch>>,
48        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49    ) -> Self {
50        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
51    }
52
53    /// Create a function executor with dispatch and token issuer.
54    pub fn with_dispatch_and_issuer(
55        registry: Arc<FunctionRegistry>,
56        db: Database,
57        job_dispatcher: Option<Arc<dyn JobDispatch>>,
58        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
59        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
60    ) -> Self {
61        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
62        if let Some(jd) = job_dispatcher {
63            router = router.with_job_dispatcher(jd);
64        }
65        if let Some(wd) = workflow_dispatcher {
66            router = router.with_workflow_dispatcher(wd);
67        }
68        if let Some(issuer) = token_issuer {
69            router = router.with_token_issuer(issuer);
70        }
71        Self {
72            router,
73            registry,
74            default_timeout: Duration::from_secs(30),
75        }
76    }
77
78    /// Execute a function call.
79    pub async fn execute(
80        &self,
81        function_name: &str,
82        args: Value,
83        auth: AuthContext,
84        request: RequestMetadata,
85    ) -> Result<ExecutionResult> {
86        let start = std::time::Instant::now();
87        let fn_timeout = self.get_function_timeout(function_name);
88        let log_level = self.get_function_log_level(function_name);
89
90        let kind = self
91            .router
92            .get_function_kind(function_name)
93            .map(|k| k.to_string())
94            .unwrap_or_else(|| "unknown".to_string());
95
96        let span = tracing::info_span!(
97            "fn.execute",
98            function = function_name,
99            kind = %kind,
100        );
101
102        let result = match timeout(
103            fn_timeout,
104            self.router
105                .route(function_name, args.clone(), auth, request)
106                .instrument(span),
107        )
108        .await
109        {
110            Ok(result) => result,
111            Err(_) => {
112                let duration = start.elapsed();
113                self.log_execution(
114                    log_level,
115                    function_name,
116                    "unknown",
117                    &args,
118                    duration,
119                    false,
120                    Some(&format!("Timeout after {:?}", fn_timeout)),
121                );
122                crate::observability::record_fn_execution(
123                    function_name,
124                    &kind,
125                    false,
126                    duration.as_secs_f64(),
127                );
128                return Err(ForgeError::Timeout(format!(
129                    "Function '{}' timed out after {:?}",
130                    function_name, fn_timeout
131                )));
132            }
133        };
134
135        let duration = start.elapsed();
136
137        match result {
138            Ok(route_result) => {
139                let (result_kind, value) = match route_result {
140                    RouteResult::Query(v) => ("query", v),
141                    RouteResult::Mutation(v) => ("mutation", v),
142                    RouteResult::Job(v) => ("job", v),
143                    RouteResult::Workflow(v) => ("workflow", v),
144                };
145
146                self.log_execution(
147                    log_level,
148                    function_name,
149                    result_kind,
150                    &args,
151                    duration,
152                    true,
153                    None,
154                );
155                crate::observability::record_fn_execution(
156                    function_name,
157                    result_kind,
158                    true,
159                    duration.as_secs_f64(),
160                );
161
162                Ok(ExecutionResult {
163                    function_name: function_name.to_string(),
164                    function_kind: result_kind.to_string(),
165                    result: value,
166                    duration,
167                    success: true,
168                    error: None,
169                })
170            }
171            Err(e) => {
172                self.log_execution(
173                    log_level,
174                    function_name,
175                    &kind,
176                    &args,
177                    duration,
178                    false,
179                    Some(&e.to_string()),
180                );
181                crate::observability::record_fn_execution(
182                    function_name,
183                    &kind,
184                    false,
185                    duration.as_secs_f64(),
186                );
187
188                Err(e)
189            }
190        }
191    }
192
193    /// Log function execution at the configured level.
194    #[allow(clippy::too_many_arguments)]
195    fn log_execution(
196        &self,
197        log_level: &str,
198        function_name: &str,
199        kind: &str,
200        input: &Value,
201        duration: Duration,
202        success: bool,
203        error: Option<&str>,
204    ) {
205        // Failures are always logged at error regardless of the function's
206        // configured log level. Successes use the configured level.
207        if !success {
208            error!(
209                function = function_name,
210                kind = kind,
211                duration_ms = duration.as_millis() as u64,
212                error = error,
213                "Function failed"
214            );
215            debug!(
216                function = function_name,
217                input = %input,
218                "Function input"
219            );
220            return;
221        }
222
223        macro_rules! log_fn {
224            ($level:ident) => {{
225                $level!(
226                    function = function_name,
227                    kind = kind,
228                    duration_ms = duration.as_millis() as u64,
229                    "Function executed"
230                );
231                debug!(
232                    function = function_name,
233                    input = %input,
234                    "Function input"
235                );
236            }};
237        }
238
239        match log_level {
240            "off" => {}
241            "error" => log_fn!(error),
242            "warn" => log_fn!(warn),
243            "info" => log_fn!(info),
244            "debug" => log_fn!(debug),
245            _ => log_fn!(trace),
246        }
247    }
248
249    /// Mutations default to "info" because writes are worth tracking.
250    /// Queries default to "debug" since they're high-volume.
251    fn get_function_log_level(&self, function_name: &str) -> &'static str {
252        self.registry
253            .get(function_name)
254            .map(|entry| {
255                entry.info().log_level.unwrap_or(match entry.kind() {
256                    forge_core::FunctionKind::Mutation => "info",
257                    forge_core::FunctionKind::Query => "debug",
258                })
259            })
260            .unwrap_or("info")
261    }
262
263    /// Get the timeout for a specific function.
264    fn get_function_timeout(&self, function_name: &str) -> Duration {
265        self.registry
266            .get(function_name)
267            .and_then(|entry| entry.info().timeout)
268            .map(Duration::from_secs)
269            .unwrap_or(self.default_timeout)
270    }
271
272    /// Check if a function exists.
273    pub fn has_function(&self, function_name: &str) -> bool {
274        self.router.has_function(function_name)
275    }
276}
277
278/// Result of executing a function.
279#[derive(Debug, Clone, serde::Serialize)]
280pub struct ExecutionResult {
281    /// Function name that was executed.
282    pub function_name: String,
283    /// Kind of function (query, mutation).
284    pub function_kind: String,
285    /// The result value (or null on error).
286    pub result: Value,
287    /// Execution duration.
288    #[serde(with = "duration_millis")]
289    pub duration: Duration,
290    /// Whether execution succeeded.
291    pub success: bool,
292    /// Error message if failed.
293    pub error: Option<String>,
294}
295
296mod duration_millis {
297    use serde::{Deserialize, Deserializer, Serializer};
298    use std::time::Duration;
299
300    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
301    where
302        S: Serializer,
303    {
304        serializer.serialize_u64(duration.as_millis() as u64)
305    }
306
307    #[allow(dead_code)]
308    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
309    where
310        D: Deserializer<'de>,
311    {
312        let millis = u64::deserialize(deserializer)?;
313        Ok(Duration::from_millis(millis))
314    }
315}
316
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// A successful query result used across the serialization tests.
    fn sample_result() -> ExecutionResult {
        ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        }
    }

    #[test]
    fn test_execution_result_serialization() {
        let json = serde_json::to_string(&sample_result()).unwrap();
        assert!(json.contains("\"duration\":42"));
        assert!(json.contains("\"success\":true"));
    }

    /// Pins the full wire shape: every field present, duration as integer
    /// milliseconds, and an absent error encoded as `null`.
    #[test]
    fn test_execution_result_serialized_shape() {
        let value = serde_json::to_value(sample_result()).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "function_name": "get_user",
                "function_kind": "query",
                "result": {"id": "123"},
                "duration": 42,
                "success": true,
                "error": null,
            })
        );
    }

    /// Exercises the otherwise-unused deserialize half of `duration_millis`.
    #[test]
    fn test_duration_millis_deserialize() {
        let duration = super::duration_millis::deserialize(serde_json::json!(42)).unwrap();
        assert_eq!(duration, Duration::from_millis(42));
    }
}