// forge_runtime/function/executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13/// Executes functions with timeout and error handling.
14pub struct FunctionExecutor {
15    router: FunctionRouter,
16    registry: Arc<FunctionRegistry>,
17    default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21    /// Create a new function executor.
22    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23        Self {
24            router: FunctionRouter::new(Arc::clone(&registry), db),
25            registry,
26            default_timeout: Duration::from_secs(30),
27        }
28    }
29
30    /// Create a new function executor with custom timeout.
31    pub fn with_timeout(
32        registry: Arc<FunctionRegistry>,
33        db: Database,
34        default_timeout: Duration,
35    ) -> Self {
36        Self {
37            router: FunctionRouter::new(Arc::clone(&registry), db),
38            registry,
39            default_timeout,
40        }
41    }
42
43    /// Create a new function executor with dispatch capabilities.
44    pub fn with_dispatch(
45        registry: Arc<FunctionRegistry>,
46        db: Database,
47        job_dispatcher: Option<Arc<dyn JobDispatch>>,
48        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49    ) -> Self {
50        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
51        if let Some(jd) = job_dispatcher {
52            router = router.with_job_dispatcher(jd);
53        }
54        if let Some(wd) = workflow_dispatcher {
55            router = router.with_workflow_dispatcher(wd);
56        }
57        Self {
58            router,
59            registry,
60            default_timeout: Duration::from_secs(30),
61        }
62    }
63
64    /// Execute a function call.
65    pub async fn execute(
66        &self,
67        function_name: &str,
68        args: Value,
69        auth: AuthContext,
70        request: RequestMetadata,
71    ) -> Result<ExecutionResult> {
72        let start = std::time::Instant::now();
73
74        // Get function-specific timeout or use default
75        let fn_timeout = self.get_function_timeout(function_name);
76
77        // Get log level for this function (default to trace)
78        let log_level = self.get_function_log_level(function_name);
79
80        // Execute with timeout
81        let result = match timeout(
82            fn_timeout,
83            self.router
84                .route(function_name, args.clone(), auth, request),
85        )
86        .await
87        {
88            Ok(result) => result,
89            Err(_) => {
90                let duration = start.elapsed();
91                self.log_execution(
92                    log_level,
93                    function_name,
94                    "unknown",
95                    &args,
96                    duration,
97                    false,
98                    Some(&format!("Timeout after {:?}", fn_timeout)),
99                );
100                return Err(ForgeError::Timeout(format!(
101                    "Function '{}' timed out after {:?}",
102                    function_name, fn_timeout
103                )));
104            }
105        };
106
107        let duration = start.elapsed();
108
109        match result {
110            Ok(route_result) => {
111                let (kind, value) = match route_result {
112                    RouteResult::Query(v) => ("query", v),
113                    RouteResult::Mutation(v) => ("mutation", v),
114                    RouteResult::Job(v) => ("job", v),
115                    RouteResult::Workflow(v) => ("workflow", v),
116                };
117
118                self.log_execution(log_level, function_name, kind, &args, duration, true, None);
119
120                Ok(ExecutionResult {
121                    function_name: function_name.to_string(),
122                    function_kind: kind.to_string(),
123                    result: value,
124                    duration,
125                    success: true,
126                    error: None,
127                })
128            }
129            Err(e) => {
130                let kind = self
131                    .router
132                    .get_function_kind(function_name)
133                    .map(|k| k.to_string())
134                    .unwrap_or_else(|| "unknown".to_string());
135
136                self.log_execution(
137                    log_level,
138                    function_name,
139                    &kind,
140                    &args,
141                    duration,
142                    false,
143                    Some(&e.to_string()),
144                );
145
146                Err(e)
147            }
148        }
149    }
150
151    /// Log function execution at the configured level.
152    #[allow(clippy::too_many_arguments)]
153    fn log_execution(
154        &self,
155        log_level: &str,
156        function_name: &str,
157        kind: &str,
158        input: &Value,
159        duration: Duration,
160        success: bool,
161        error: Option<&str>,
162    ) {
163        macro_rules! log_fn {
164            ($level:ident) => {
165                if success {
166                    $level!(
167                        function = function_name,
168                        kind = kind,
169                        input = %input,
170                        duration_ms = duration.as_millis() as u64,
171                        success = success,
172                        "Function executed"
173                    );
174                } else {
175                    $level!(
176                        function = function_name,
177                        kind = kind,
178                        input = %input,
179                        duration_ms = duration.as_millis() as u64,
180                        success = success,
181                        error = error,
182                        "Function failed"
183                    );
184                }
185            };
186        }
187
188        match log_level {
189            "off" => {}
190            "error" => log_fn!(error),
191            "warn" => log_fn!(warn),
192            "info" => log_fn!(info),
193            "debug" => log_fn!(debug),
194            _ => log_fn!(trace),
195        }
196    }
197
198    /// Get the log level for a specific function.
199    fn get_function_log_level(&self, function_name: &str) -> &'static str {
200        self.registry
201            .get(function_name)
202            .and_then(|entry| entry.info().log_level)
203            .unwrap_or("trace")
204    }
205
206    /// Get the timeout for a specific function.
207    fn get_function_timeout(&self, function_name: &str) -> Duration {
208        self.registry
209            .get(function_name)
210            .and_then(|entry| entry.info().timeout)
211            .map(Duration::from_secs)
212            .unwrap_or(self.default_timeout)
213    }
214
215    /// Check if a function exists.
216    pub fn has_function(&self, function_name: &str) -> bool {
217        self.router.has_function(function_name)
218    }
219}
220
221/// Result of executing a function.
222#[derive(Debug, Clone, serde::Serialize)]
223pub struct ExecutionResult {
224    /// Function name that was executed.
225    pub function_name: String,
226    /// Kind of function (query, mutation).
227    pub function_kind: String,
228    /// The result value (or null on error).
229    pub result: Value,
230    /// Execution duration.
231    #[serde(with = "duration_millis")]
232    pub duration: Duration,
233    /// Whether execution succeeded.
234    pub success: bool,
235    /// Error message if failed.
236    pub error: Option<String>,
237}
238
239mod duration_millis {
240    use serde::{Deserialize, Deserializer, Serializer};
241    use std::time::Duration;
242
243    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
244    where
245        S: Serializer,
246    {
247        serializer.serialize_u64(duration.as_millis() as u64)
248    }
249
250    #[allow(dead_code)]
251    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
252    where
253        D: Deserializer<'de>,
254    {
255        let millis = u64::deserialize(deserializer)?;
256        Ok(Duration::from_millis(millis))
257    }
258}
259
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Checks that `ExecutionResult` serializes `duration` as plain integer
    /// milliseconds and keeps the `success` flag.
    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();
        for fragment in ["\"duration\":42", "\"success\":true"] {
            assert!(encoded.contains(fragment));
        }
    }
}
280}