forge_runtime/function/
executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11
12/// Executes functions with timeout and error handling.
13pub struct FunctionExecutor {
14    router: FunctionRouter,
15    registry: Arc<FunctionRegistry>,
16    default_timeout: Duration,
17}
18
19impl FunctionExecutor {
20    /// Create a new function executor.
21    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
22        Self {
23            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
24            registry,
25            default_timeout: Duration::from_secs(30),
26        }
27    }
28
29    /// Create a new function executor with custom timeout.
30    pub fn with_timeout(
31        registry: Arc<FunctionRegistry>,
32        db_pool: sqlx::PgPool,
33        default_timeout: Duration,
34    ) -> Self {
35        Self {
36            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
37            registry,
38            default_timeout,
39        }
40    }
41
42    /// Create a new function executor with dispatch capabilities.
43    pub fn with_dispatch(
44        registry: Arc<FunctionRegistry>,
45        db_pool: sqlx::PgPool,
46        job_dispatcher: Option<Arc<dyn JobDispatch>>,
47        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48    ) -> Self {
49        let mut router = FunctionRouter::new(Arc::clone(&registry), db_pool);
50        if let Some(jd) = job_dispatcher {
51            router = router.with_job_dispatcher(jd);
52        }
53        if let Some(wd) = workflow_dispatcher {
54            router = router.with_workflow_dispatcher(wd);
55        }
56        Self {
57            router,
58            registry,
59            default_timeout: Duration::from_secs(30),
60        }
61    }
62
63    /// Execute a function call.
64    pub async fn execute(
65        &self,
66        function_name: &str,
67        args: Value,
68        auth: AuthContext,
69        request: RequestMetadata,
70    ) -> Result<ExecutionResult> {
71        let start = std::time::Instant::now();
72
73        // Get function-specific timeout or use default
74        let fn_timeout = self.get_function_timeout(function_name);
75
76        // Get log level for this function (default to trace)
77        let log_level = self.get_function_log_level(function_name);
78
79        // Execute with timeout
80        let result = match timeout(
81            fn_timeout,
82            self.router
83                .route(function_name, args.clone(), auth, request),
84        )
85        .await
86        {
87            Ok(result) => result,
88            Err(_) => {
89                let duration = start.elapsed();
90                self.log_execution(
91                    log_level,
92                    function_name,
93                    "unknown",
94                    &args,
95                    duration,
96                    false,
97                    Some(&format!("Timeout after {:?}", fn_timeout)),
98                );
99                return Err(ForgeError::Timeout(format!(
100                    "Function '{}' timed out after {:?}",
101                    function_name, fn_timeout
102                )));
103            }
104        };
105
106        let duration = start.elapsed();
107
108        match result {
109            Ok(route_result) => {
110                let (kind, value) = match route_result {
111                    RouteResult::Query(v) => ("query", v),
112                    RouteResult::Mutation(v) => ("mutation", v),
113                    RouteResult::Action(v) => ("action", v),
114                };
115
116                self.log_execution(log_level, function_name, kind, &args, duration, true, None);
117
118                Ok(ExecutionResult {
119                    function_name: function_name.to_string(),
120                    function_kind: kind.to_string(),
121                    result: value,
122                    duration,
123                    success: true,
124                    error: None,
125                })
126            }
127            Err(e) => {
128                let kind = self
129                    .router
130                    .get_function_kind(function_name)
131                    .map(|k| k.to_string())
132                    .unwrap_or_else(|| "unknown".to_string());
133
134                self.log_execution(
135                    log_level,
136                    function_name,
137                    &kind,
138                    &args,
139                    duration,
140                    false,
141                    Some(&e.to_string()),
142                );
143
144                Ok(ExecutionResult {
145                    function_name: function_name.to_string(),
146                    function_kind: kind,
147                    result: Value::Null,
148                    duration,
149                    success: false,
150                    error: Some(e.to_string()),
151                })
152            }
153        }
154    }
155
156    /// Log function execution at the configured level.
157    #[allow(clippy::too_many_arguments)]
158    fn log_execution(
159        &self,
160        log_level: &str,
161        function_name: &str,
162        kind: &str,
163        input: &Value,
164        duration: Duration,
165        success: bool,
166        error: Option<&str>,
167    ) {
168        let duration_ms = duration.as_millis();
169        let input_str = input.to_string();
170
171        match log_level {
172            "off" => {}
173            "error" => {
174                if success {
175                    error!(
176                        function = function_name,
177                        kind = kind,
178                        input = input_str,
179                        duration_ms = duration_ms,
180                        success = success,
181                        "Function executed"
182                    );
183                } else {
184                    error!(
185                        function = function_name,
186                        kind = kind,
187                        input = input_str,
188                        duration_ms = duration_ms,
189                        success = success,
190                        error = error,
191                        "Function failed"
192                    );
193                }
194            }
195            "warn" => {
196                if success {
197                    warn!(
198                        function = function_name,
199                        kind = kind,
200                        input = input_str,
201                        duration_ms = duration_ms,
202                        success = success,
203                        "Function executed"
204                    );
205                } else {
206                    warn!(
207                        function = function_name,
208                        kind = kind,
209                        input = input_str,
210                        duration_ms = duration_ms,
211                        success = success,
212                        error = error,
213                        "Function failed"
214                    );
215                }
216            }
217            "info" => {
218                if success {
219                    info!(
220                        function = function_name,
221                        kind = kind,
222                        input = input_str,
223                        duration_ms = duration_ms,
224                        success = success,
225                        "Function executed"
226                    );
227                } else {
228                    info!(
229                        function = function_name,
230                        kind = kind,
231                        input = input_str,
232                        duration_ms = duration_ms,
233                        success = success,
234                        error = error,
235                        "Function failed"
236                    );
237                }
238            }
239            "debug" => {
240                if success {
241                    debug!(
242                        function = function_name,
243                        kind = kind,
244                        input = input_str,
245                        duration_ms = duration_ms,
246                        success = success,
247                        "Function executed"
248                    );
249                } else {
250                    debug!(
251                        function = function_name,
252                        kind = kind,
253                        input = input_str,
254                        duration_ms = duration_ms,
255                        success = success,
256                        error = error,
257                        "Function failed"
258                    );
259                }
260            }
261            // Default to trace
262            _ => {
263                if success {
264                    trace!(
265                        function = function_name,
266                        kind = kind,
267                        input = input_str,
268                        duration_ms = duration_ms,
269                        success = success,
270                        "Function executed"
271                    );
272                } else {
273                    trace!(
274                        function = function_name,
275                        kind = kind,
276                        input = input_str,
277                        duration_ms = duration_ms,
278                        success = success,
279                        error = error,
280                        "Function failed"
281                    );
282                }
283            }
284        }
285    }
286
287    /// Get the log level for a specific function.
288    fn get_function_log_level(&self, function_name: &str) -> &'static str {
289        self.registry
290            .get(function_name)
291            .and_then(|entry| entry.info().log_level)
292            .unwrap_or("trace")
293    }
294
295    /// Get the timeout for a specific function.
296    fn get_function_timeout(&self, function_name: &str) -> Duration {
297        self.registry
298            .get(function_name)
299            .and_then(|entry| entry.info().timeout)
300            .map(Duration::from_secs)
301            .unwrap_or(self.default_timeout)
302    }
303
304    /// Check if a function exists.
305    pub fn has_function(&self, function_name: &str) -> bool {
306        self.router.has_function(function_name)
307    }
308}
309
310/// Result of executing a function.
311#[derive(Debug, Clone, serde::Serialize)]
312pub struct ExecutionResult {
313    /// Function name that was executed.
314    pub function_name: String,
315    /// Kind of function (query, mutation, action).
316    pub function_kind: String,
317    /// The result value (or null on error).
318    pub result: Value,
319    /// Execution duration.
320    #[serde(with = "duration_millis")]
321    pub duration: Duration,
322    /// Whether execution succeeded.
323    pub success: bool,
324    /// Error message if failed.
325    pub error: Option<String>,
326}
327
328mod duration_millis {
329    use serde::{Deserialize, Deserializer, Serializer};
330    use std::time::Duration;
331
332    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
333    where
334        S: Serializer,
335    {
336        serializer.serialize_u64(duration.as_millis() as u64)
337    }
338
339    #[allow(dead_code)]
340    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
341    where
342        D: Deserializer<'de>,
343    {
344        let millis = u64::deserialize(deserializer)?;
345        Ok(Duration::from_millis(millis))
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful result serializes every field, with `duration` in milliseconds
    /// and a `null` error.
    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "function_name": "get_user",
                "function_kind": "query",
                "result": {"id": "123"},
                "duration": 42,
                "success": true,
                "error": null
            })
        );
    }

    /// A failed result carries its error message, and sub-millisecond precision is
    /// truncated from `duration`.
    #[test]
    fn test_failed_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "delete_user".to_string(),
            function_kind: "unknown".to_string(),
            result: Value::Null,
            duration: Duration::from_micros(1_999),
            success: false,
            error: Some("boom".to_string()),
        };

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "function_name": "delete_user",
                "function_kind": "unknown",
                "result": null,
                "duration": 1,
                "success": false,
                "error": "boom"
            })
        );
    }
}