Skip to main content

forge_runtime/function/
executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13/// Executes functions with timeout and error handling.
14pub struct FunctionExecutor {
15    router: FunctionRouter,
16    registry: Arc<FunctionRegistry>,
17    default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21    /// Create a new function executor.
22    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23        Self {
24            router: FunctionRouter::new(Arc::clone(&registry), db),
25            registry,
26            default_timeout: Duration::from_secs(30),
27        }
28    }
29
30    /// Create a new function executor with custom timeout.
31    pub fn with_timeout(
32        registry: Arc<FunctionRegistry>,
33        db: Database,
34        default_timeout: Duration,
35    ) -> Self {
36        Self {
37            router: FunctionRouter::new(Arc::clone(&registry), db),
38            registry,
39            default_timeout,
40        }
41    }
42
43    /// Create a new function executor with dispatch capabilities.
44    pub fn with_dispatch(
45        registry: Arc<FunctionRegistry>,
46        db: Database,
47        job_dispatcher: Option<Arc<dyn JobDispatch>>,
48        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49    ) -> Self {
50        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
51        if let Some(jd) = job_dispatcher {
52            router = router.with_job_dispatcher(jd);
53        }
54        if let Some(wd) = workflow_dispatcher {
55            router = router.with_workflow_dispatcher(wd);
56        }
57        Self {
58            router,
59            registry,
60            default_timeout: Duration::from_secs(30),
61        }
62    }
63
64    /// Execute a function call.
65    pub async fn execute(
66        &self,
67        function_name: &str,
68        args: Value,
69        auth: AuthContext,
70        request: RequestMetadata,
71    ) -> Result<ExecutionResult> {
72        let start = std::time::Instant::now();
73        let fn_timeout = self.get_function_timeout(function_name);
74        let log_level = self.get_function_log_level(function_name);
75
76        let kind = self
77            .router
78            .get_function_kind(function_name)
79            .map(|k| k.to_string())
80            .unwrap_or_else(|| "unknown".to_string());
81
82        let span = tracing::info_span!(
83            "fn.execute",
84            function = function_name,
85            kind = %kind,
86        );
87
88        let result = match timeout(
89            fn_timeout,
90            self.router
91                .route(function_name, args.clone(), auth, request)
92                .instrument(span),
93        )
94        .await
95        {
96            Ok(result) => result,
97            Err(_) => {
98                let duration = start.elapsed();
99                self.log_execution(
100                    log_level,
101                    function_name,
102                    "unknown",
103                    &args,
104                    duration,
105                    false,
106                    Some(&format!("Timeout after {:?}", fn_timeout)),
107                );
108                return Err(ForgeError::Timeout(format!(
109                    "Function '{}' timed out after {:?}",
110                    function_name, fn_timeout
111                )));
112            }
113        };
114
115        let duration = start.elapsed();
116
117        match result {
118            Ok(route_result) => {
119                let (result_kind, value) = match route_result {
120                    RouteResult::Query(v) => ("query", v),
121                    RouteResult::Mutation(v) => ("mutation", v),
122                    RouteResult::Job(v) => ("job", v),
123                    RouteResult::Workflow(v) => ("workflow", v),
124                };
125
126                self.log_execution(
127                    log_level,
128                    function_name,
129                    result_kind,
130                    &args,
131                    duration,
132                    true,
133                    None,
134                );
135
136                Ok(ExecutionResult {
137                    function_name: function_name.to_string(),
138                    function_kind: result_kind.to_string(),
139                    result: value,
140                    duration,
141                    success: true,
142                    error: None,
143                })
144            }
145            Err(e) => {
146                self.log_execution(
147                    log_level,
148                    function_name,
149                    &kind,
150                    &args,
151                    duration,
152                    false,
153                    Some(&e.to_string()),
154                );
155
156                Err(e)
157            }
158        }
159    }
160
161    /// Log function execution at the configured level.
162    #[allow(clippy::too_many_arguments)]
163    fn log_execution(
164        &self,
165        log_level: &str,
166        function_name: &str,
167        kind: &str,
168        input: &Value,
169        duration: Duration,
170        success: bool,
171        error: Option<&str>,
172    ) {
173        macro_rules! log_fn {
174            ($level:ident) => {
175                if success {
176                    $level!(
177                        function = function_name,
178                        kind = kind,
179                        input = %input,
180                        duration_ms = duration.as_millis() as u64,
181                        success = success,
182                        "Function executed"
183                    );
184                } else {
185                    $level!(
186                        function = function_name,
187                        kind = kind,
188                        input = %input,
189                        duration_ms = duration.as_millis() as u64,
190                        success = success,
191                        error = error,
192                        "Function failed"
193                    );
194                }
195            };
196        }
197
198        match log_level {
199            "off" => {}
200            "error" => log_fn!(error),
201            "warn" => log_fn!(warn),
202            "info" => log_fn!(info),
203            "debug" => log_fn!(debug),
204            _ => log_fn!(trace),
205        }
206    }
207
208    /// Mutations default to "info" because writes are worth tracking.
209    /// Queries default to "debug" since they're high-volume.
210    fn get_function_log_level(&self, function_name: &str) -> &'static str {
211        self.registry
212            .get(function_name)
213            .map(|entry| {
214                entry.info().log_level.unwrap_or(match entry.kind() {
215                    forge_core::FunctionKind::Mutation => "info",
216                    forge_core::FunctionKind::Query => "debug",
217                })
218            })
219            .unwrap_or("info")
220    }
221
222    /// Get the timeout for a specific function.
223    fn get_function_timeout(&self, function_name: &str) -> Duration {
224        self.registry
225            .get(function_name)
226            .and_then(|entry| entry.info().timeout)
227            .map(Duration::from_secs)
228            .unwrap_or(self.default_timeout)
229    }
230
231    /// Check if a function exists.
232    pub fn has_function(&self, function_name: &str) -> bool {
233        self.router.has_function(function_name)
234    }
235}
236
237/// Result of executing a function.
238#[derive(Debug, Clone, serde::Serialize)]
239pub struct ExecutionResult {
240    /// Function name that was executed.
241    pub function_name: String,
242    /// Kind of function (query, mutation).
243    pub function_kind: String,
244    /// The result value (or null on error).
245    pub result: Value,
246    /// Execution duration.
247    #[serde(with = "duration_millis")]
248    pub duration: Duration,
249    /// Whether execution succeeded.
250    pub success: bool,
251    /// Error message if failed.
252    pub error: Option<String>,
253}
254
255mod duration_millis {
256    use serde::{Deserialize, Deserializer, Serializer};
257    use std::time::Duration;
258
259    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
260    where
261        S: Serializer,
262    {
263        serializer.serialize_u64(duration.as_millis() as u64)
264    }
265
266    #[allow(dead_code)]
267    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
268    where
269        D: Deserializer<'de>,
270    {
271        let millis = u64::deserialize(deserializer)?;
272        Ok(Duration::from_millis(millis))
273    }
274}
275
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        // Inspect structured JSON rather than substrings, so key order and
        // whitespace cannot affect the outcome, and every field is checked.
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["function_name"], "get_user");
        assert_eq!(json["function_kind"], "query");
        assert_eq!(json["result"], serde_json::json!({"id": "123"}));
        // Duration must serialize as a plain integer millisecond count.
        assert_eq!(json["duration"], 42);
        assert_eq!(json["success"], true);
        assert!(json["error"].is_null());
    }
}