// forge_runtime/function/executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11
12/// Executes functions with timeout and error handling.
13pub struct FunctionExecutor {
14    router: FunctionRouter,
15    registry: Arc<FunctionRegistry>,
16    default_timeout: Duration,
17}
18
19impl FunctionExecutor {
20    /// Create a new function executor.
21    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
22        Self {
23            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
24            registry,
25            default_timeout: Duration::from_secs(30),
26        }
27    }
28
29    /// Create a new function executor with custom timeout.
30    pub fn with_timeout(
31        registry: Arc<FunctionRegistry>,
32        db_pool: sqlx::PgPool,
33        default_timeout: Duration,
34    ) -> Self {
35        Self {
36            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
37            registry,
38            default_timeout,
39        }
40    }
41
42    /// Create a new function executor with dispatch capabilities.
43    pub fn with_dispatch(
44        registry: Arc<FunctionRegistry>,
45        db_pool: sqlx::PgPool,
46        job_dispatcher: Option<Arc<dyn JobDispatch>>,
47        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48    ) -> Self {
49        let mut router = FunctionRouter::new(Arc::clone(&registry), db_pool);
50        if let Some(jd) = job_dispatcher {
51            router = router.with_job_dispatcher(jd);
52        }
53        if let Some(wd) = workflow_dispatcher {
54            router = router.with_workflow_dispatcher(wd);
55        }
56        Self {
57            router,
58            registry,
59            default_timeout: Duration::from_secs(30),
60        }
61    }
62
63    /// Execute a function call.
64    pub async fn execute(
65        &self,
66        function_name: &str,
67        args: Value,
68        auth: AuthContext,
69        request: RequestMetadata,
70    ) -> Result<ExecutionResult> {
71        let start = std::time::Instant::now();
72
73        // Get function-specific timeout or use default
74        let fn_timeout = self.get_function_timeout(function_name);
75
76        // Get log level for this function (default to trace)
77        let log_level = self.get_function_log_level(function_name);
78
79        // Execute with timeout
80        let result = match timeout(
81            fn_timeout,
82            self.router
83                .route(function_name, args.clone(), auth, request),
84        )
85        .await
86        {
87            Ok(result) => result,
88            Err(_) => {
89                let duration = start.elapsed();
90                self.log_execution(
91                    log_level,
92                    function_name,
93                    "unknown",
94                    &args,
95                    duration,
96                    false,
97                    Some(&format!("Timeout after {:?}", fn_timeout)),
98                );
99                return Err(ForgeError::Timeout(format!(
100                    "Function '{}' timed out after {:?}",
101                    function_name, fn_timeout
102                )));
103            }
104        };
105
106        let duration = start.elapsed();
107
108        match result {
109            Ok(route_result) => {
110                let (kind, value) = match route_result {
111                    RouteResult::Query(v) => ("query", v),
112                    RouteResult::Mutation(v) => ("mutation", v),
113                };
114
115                self.log_execution(log_level, function_name, kind, &args, duration, true, None);
116
117                Ok(ExecutionResult {
118                    function_name: function_name.to_string(),
119                    function_kind: kind.to_string(),
120                    result: value,
121                    duration,
122                    success: true,
123                    error: None,
124                })
125            }
126            Err(e) => {
127                let kind = self
128                    .router
129                    .get_function_kind(function_name)
130                    .map(|k| k.to_string())
131                    .unwrap_or_else(|| "unknown".to_string());
132
133                self.log_execution(
134                    log_level,
135                    function_name,
136                    &kind,
137                    &args,
138                    duration,
139                    false,
140                    Some(&e.to_string()),
141                );
142
143                Ok(ExecutionResult {
144                    function_name: function_name.to_string(),
145                    function_kind: kind,
146                    result: Value::Null,
147                    duration,
148                    success: false,
149                    error: Some(e.to_string()),
150                })
151            }
152        }
153    }
154
155    /// Log function execution at the configured level.
156    #[allow(clippy::too_many_arguments)]
157    fn log_execution(
158        &self,
159        log_level: &str,
160        function_name: &str,
161        kind: &str,
162        input: &Value,
163        duration: Duration,
164        success: bool,
165        error: Option<&str>,
166    ) {
167        let duration_ms = duration.as_millis();
168        let input_str = input.to_string();
169
170        match log_level {
171            "off" => {}
172            "error" => {
173                if success {
174                    error!(
175                        function = function_name,
176                        kind = kind,
177                        input = input_str,
178                        duration_ms = duration_ms,
179                        success = success,
180                        "Function executed"
181                    );
182                } else {
183                    error!(
184                        function = function_name,
185                        kind = kind,
186                        input = input_str,
187                        duration_ms = duration_ms,
188                        success = success,
189                        error = error,
190                        "Function failed"
191                    );
192                }
193            }
194            "warn" => {
195                if success {
196                    warn!(
197                        function = function_name,
198                        kind = kind,
199                        input = input_str,
200                        duration_ms = duration_ms,
201                        success = success,
202                        "Function executed"
203                    );
204                } else {
205                    warn!(
206                        function = function_name,
207                        kind = kind,
208                        input = input_str,
209                        duration_ms = duration_ms,
210                        success = success,
211                        error = error,
212                        "Function failed"
213                    );
214                }
215            }
216            "info" => {
217                if success {
218                    info!(
219                        function = function_name,
220                        kind = kind,
221                        input = input_str,
222                        duration_ms = duration_ms,
223                        success = success,
224                        "Function executed"
225                    );
226                } else {
227                    info!(
228                        function = function_name,
229                        kind = kind,
230                        input = input_str,
231                        duration_ms = duration_ms,
232                        success = success,
233                        error = error,
234                        "Function failed"
235                    );
236                }
237            }
238            "debug" => {
239                if success {
240                    debug!(
241                        function = function_name,
242                        kind = kind,
243                        input = input_str,
244                        duration_ms = duration_ms,
245                        success = success,
246                        "Function executed"
247                    );
248                } else {
249                    debug!(
250                        function = function_name,
251                        kind = kind,
252                        input = input_str,
253                        duration_ms = duration_ms,
254                        success = success,
255                        error = error,
256                        "Function failed"
257                    );
258                }
259            }
260            // Default to trace
261            _ => {
262                if success {
263                    trace!(
264                        function = function_name,
265                        kind = kind,
266                        input = input_str,
267                        duration_ms = duration_ms,
268                        success = success,
269                        "Function executed"
270                    );
271                } else {
272                    trace!(
273                        function = function_name,
274                        kind = kind,
275                        input = input_str,
276                        duration_ms = duration_ms,
277                        success = success,
278                        error = error,
279                        "Function failed"
280                    );
281                }
282            }
283        }
284    }
285
286    /// Get the log level for a specific function.
287    fn get_function_log_level(&self, function_name: &str) -> &'static str {
288        self.registry
289            .get(function_name)
290            .and_then(|entry| entry.info().log_level)
291            .unwrap_or("trace")
292    }
293
294    /// Get the timeout for a specific function.
295    fn get_function_timeout(&self, function_name: &str) -> Duration {
296        self.registry
297            .get(function_name)
298            .and_then(|entry| entry.info().timeout)
299            .map(Duration::from_secs)
300            .unwrap_or(self.default_timeout)
301    }
302
303    /// Check if a function exists.
304    pub fn has_function(&self, function_name: &str) -> bool {
305        self.router.has_function(function_name)
306    }
307}
308
309/// Result of executing a function.
310#[derive(Debug, Clone, serde::Serialize)]
311pub struct ExecutionResult {
312    /// Function name that was executed.
313    pub function_name: String,
314    /// Kind of function (query, mutation).
315    pub function_kind: String,
316    /// The result value (or null on error).
317    pub result: Value,
318    /// Execution duration.
319    #[serde(with = "duration_millis")]
320    pub duration: Duration,
321    /// Whether execution succeeded.
322    pub success: bool,
323    /// Error message if failed.
324    pub error: Option<String>,
325}
326
327mod duration_millis {
328    use serde::{Deserialize, Deserializer, Serializer};
329    use std::time::Duration;
330
331    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
332    where
333        S: Serializer,
334    {
335        serializer.serialize_u64(duration.as_millis() as u64)
336    }
337
338    #[allow(dead_code)]
339    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
340    where
341        D: Deserializer<'de>,
342    {
343        let millis = u64::deserialize(deserializer)?;
344        Ok(Duration::from_millis(millis))
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful result serializes every field, with `duration` as millis.
    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        // Compare the whole object so wrong keys or values cannot slip past
        // a substring match.
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "function_name": "get_user",
                "function_kind": "query",
                "result": {"id": "123"},
                "duration": 42,
                "success": true,
                "error": null
            })
        );
    }

    /// A failed result carries a null value and the error message; the
    /// sub-millisecond part of the duration is truncated.
    #[test]
    fn test_failed_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "create_user".to_string(),
            function_kind: "mutation".to_string(),
            result: Value::Null,
            duration: Duration::from_micros(1_500),
            success: false,
            error: Some("boom".to_string()),
        };

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["duration"], 1);
        assert_eq!(json["success"], false);
        assert_eq!(json["error"], "boom");
        assert!(json["result"].is_null());
    }

    /// The millisecond adapter decodes what it encodes.
    #[test]
    fn test_duration_millis_deserialize() {
        let decoded = duration_millis::deserialize(serde_json::json!(42)).unwrap();
        assert_eq!(decoded, Duration::from_millis(42));
    }
}