Skip to main content

forge_runtime/function/
executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11
12/// Executes functions with timeout and error handling.
13pub struct FunctionExecutor {
14    router: FunctionRouter,
15    registry: Arc<FunctionRegistry>,
16    default_timeout: Duration,
17}
18
19impl FunctionExecutor {
20    /// Create a new function executor.
21    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
22        Self {
23            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
24            registry,
25            default_timeout: Duration::from_secs(30),
26        }
27    }
28
29    /// Create a new function executor with custom timeout.
30    pub fn with_timeout(
31        registry: Arc<FunctionRegistry>,
32        db_pool: sqlx::PgPool,
33        default_timeout: Duration,
34    ) -> Self {
35        Self {
36            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
37            registry,
38            default_timeout,
39        }
40    }
41
42    /// Create a new function executor with dispatch capabilities.
43    pub fn with_dispatch(
44        registry: Arc<FunctionRegistry>,
45        db_pool: sqlx::PgPool,
46        job_dispatcher: Option<Arc<dyn JobDispatch>>,
47        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48    ) -> Self {
49        let mut router = FunctionRouter::new(Arc::clone(&registry), db_pool);
50        if let Some(jd) = job_dispatcher {
51            router = router.with_job_dispatcher(jd);
52        }
53        if let Some(wd) = workflow_dispatcher {
54            router = router.with_workflow_dispatcher(wd);
55        }
56        Self {
57            router,
58            registry,
59            default_timeout: Duration::from_secs(30),
60        }
61    }
62
63    /// Execute a function call.
64    pub async fn execute(
65        &self,
66        function_name: &str,
67        args: Value,
68        auth: AuthContext,
69        request: RequestMetadata,
70    ) -> Result<ExecutionResult> {
71        let start = std::time::Instant::now();
72
73        // Get function-specific timeout or use default
74        let fn_timeout = self.get_function_timeout(function_name);
75
76        // Get log level for this function (default to trace)
77        let log_level = self.get_function_log_level(function_name);
78
79        // Execute with timeout
80        let result = match timeout(
81            fn_timeout,
82            self.router
83                .route(function_name, args.clone(), auth, request),
84        )
85        .await
86        {
87            Ok(result) => result,
88            Err(_) => {
89                let duration = start.elapsed();
90                self.log_execution(
91                    log_level,
92                    function_name,
93                    "unknown",
94                    &args,
95                    duration,
96                    false,
97                    Some(&format!("Timeout after {:?}", fn_timeout)),
98                );
99                return Err(ForgeError::Timeout(format!(
100                    "Function '{}' timed out after {:?}",
101                    function_name, fn_timeout
102                )));
103            }
104        };
105
106        let duration = start.elapsed();
107
108        match result {
109            Ok(route_result) => {
110                let (kind, value) = match route_result {
111                    RouteResult::Query(v) => ("query", v),
112                    RouteResult::Mutation(v) => ("mutation", v),
113                    RouteResult::Job(v) => ("job", v),
114                    RouteResult::Workflow(v) => ("workflow", v),
115                };
116
117                self.log_execution(log_level, function_name, kind, &args, duration, true, None);
118
119                Ok(ExecutionResult {
120                    function_name: function_name.to_string(),
121                    function_kind: kind.to_string(),
122                    result: value,
123                    duration,
124                    success: true,
125                    error: None,
126                })
127            }
128            Err(e) => {
129                let kind = self
130                    .router
131                    .get_function_kind(function_name)
132                    .map(|k| k.to_string())
133                    .unwrap_or_else(|| "unknown".to_string());
134
135                self.log_execution(
136                    log_level,
137                    function_name,
138                    &kind,
139                    &args,
140                    duration,
141                    false,
142                    Some(&e.to_string()),
143                );
144
145                Err(e)
146            }
147        }
148    }
149
150    /// Log function execution at the configured level.
151    #[allow(clippy::too_many_arguments)]
152    fn log_execution(
153        &self,
154        log_level: &str,
155        function_name: &str,
156        kind: &str,
157        input: &Value,
158        duration: Duration,
159        success: bool,
160        error: Option<&str>,
161    ) {
162        macro_rules! log_fn {
163            ($level:ident) => {
164                if success {
165                    $level!(
166                        function = function_name,
167                        kind = kind,
168                        input = %input,
169                        duration_ms = duration.as_millis() as u64,
170                        success = success,
171                        "Function executed"
172                    );
173                } else {
174                    $level!(
175                        function = function_name,
176                        kind = kind,
177                        input = %input,
178                        duration_ms = duration.as_millis() as u64,
179                        success = success,
180                        error = error,
181                        "Function failed"
182                    );
183                }
184            };
185        }
186
187        match log_level {
188            "off" => {}
189            "error" => log_fn!(error),
190            "warn" => log_fn!(warn),
191            "info" => log_fn!(info),
192            "debug" => log_fn!(debug),
193            _ => log_fn!(trace),
194        }
195    }
196
197    /// Get the log level for a specific function.
198    fn get_function_log_level(&self, function_name: &str) -> &'static str {
199        self.registry
200            .get(function_name)
201            .and_then(|entry| entry.info().log_level)
202            .unwrap_or("trace")
203    }
204
205    /// Get the timeout for a specific function.
206    fn get_function_timeout(&self, function_name: &str) -> Duration {
207        self.registry
208            .get(function_name)
209            .and_then(|entry| entry.info().timeout)
210            .map(Duration::from_secs)
211            .unwrap_or(self.default_timeout)
212    }
213
214    /// Check if a function exists.
215    pub fn has_function(&self, function_name: &str) -> bool {
216        self.router.has_function(function_name)
217    }
218}
219
220/// Result of executing a function.
221#[derive(Debug, Clone, serde::Serialize)]
222pub struct ExecutionResult {
223    /// Function name that was executed.
224    pub function_name: String,
225    /// Kind of function (query, mutation).
226    pub function_kind: String,
227    /// The result value (or null on error).
228    pub result: Value,
229    /// Execution duration.
230    #[serde(with = "duration_millis")]
231    pub duration: Duration,
232    /// Whether execution succeeded.
233    pub success: bool,
234    /// Error message if failed.
235    pub error: Option<String>,
236}
237
238mod duration_millis {
239    use serde::{Deserialize, Deserializer, Serializer};
240    use std::time::Duration;
241
242    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
243    where
244        S: Serializer,
245    {
246        serializer.serialize_u64(duration.as_millis() as u64)
247    }
248
249    #[allow(dead_code)]
250    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
251    where
252        D: Deserializer<'de>,
253    {
254        let millis = u64::deserialize(deserializer)?;
255        Ok(Duration::from_millis(millis))
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing an `ExecutionResult` must render `duration` as integer
    /// milliseconds and carry the `success` flag through unchanged.
    #[test]
    fn test_execution_result_serialization() {
        let encoded = serde_json::to_string(&ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        })
        .unwrap();

        for fragment in ["\"duration\":42", "\"success\":true"] {
            assert!(encoded.contains(fragment));
        }
    }
}