// forge_runtime/function/executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13/// Executes functions with timeout and error handling.
14pub struct FunctionExecutor {
15    router: FunctionRouter,
16    registry: Arc<FunctionRegistry>,
17    default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21    /// Create a new function executor.
22    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23        Self {
24            router: FunctionRouter::new(Arc::clone(&registry), db),
25            registry,
26            default_timeout: Duration::from_secs(30),
27        }
28    }
29
30    /// Create a new function executor with custom timeout.
31    pub fn with_timeout(
32        registry: Arc<FunctionRegistry>,
33        db: Database,
34        default_timeout: Duration,
35    ) -> Self {
36        Self {
37            router: FunctionRouter::new(Arc::clone(&registry), db),
38            registry,
39            default_timeout,
40        }
41    }
42
43    /// Create a new function executor with dispatch capabilities.
44    pub fn with_dispatch(
45        registry: Arc<FunctionRegistry>,
46        db: Database,
47        job_dispatcher: Option<Arc<dyn JobDispatch>>,
48        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49    ) -> Self {
50        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
51        if let Some(jd) = job_dispatcher {
52            router = router.with_job_dispatcher(jd);
53        }
54        if let Some(wd) = workflow_dispatcher {
55            router = router.with_workflow_dispatcher(wd);
56        }
57        Self {
58            router,
59            registry,
60            default_timeout: Duration::from_secs(30),
61        }
62    }
63
64    /// Execute a function call.
65    pub async fn execute(
66        &self,
67        function_name: &str,
68        args: Value,
69        auth: AuthContext,
70        request: RequestMetadata,
71    ) -> Result<ExecutionResult> {
72        let start = std::time::Instant::now();
73        let fn_timeout = self.get_function_timeout(function_name);
74        let log_level = self.get_function_log_level(function_name);
75
76        let kind = self
77            .router
78            .get_function_kind(function_name)
79            .map(|k| k.to_string())
80            .unwrap_or_else(|| "unknown".to_string());
81
82        let span = tracing::info_span!(
83            "fn.execute",
84            function = function_name,
85            kind = %kind,
86        );
87
88        let result = match timeout(
89            fn_timeout,
90            self.router
91                .route(function_name, args.clone(), auth, request)
92                .instrument(span),
93        )
94        .await
95        {
96            Ok(result) => result,
97            Err(_) => {
98                let duration = start.elapsed();
99                self.log_execution(
100                    log_level,
101                    function_name,
102                    "unknown",
103                    &args,
104                    duration,
105                    false,
106                    Some(&format!("Timeout after {:?}", fn_timeout)),
107                );
108                crate::observability::record_fn_execution(
109                    function_name,
110                    &kind,
111                    false,
112                    duration.as_secs_f64(),
113                );
114                return Err(ForgeError::Timeout(format!(
115                    "Function '{}' timed out after {:?}",
116                    function_name, fn_timeout
117                )));
118            }
119        };
120
121        let duration = start.elapsed();
122
123        match result {
124            Ok(route_result) => {
125                let (result_kind, value) = match route_result {
126                    RouteResult::Query(v) => ("query", v),
127                    RouteResult::Mutation(v) => ("mutation", v),
128                    RouteResult::Job(v) => ("job", v),
129                    RouteResult::Workflow(v) => ("workflow", v),
130                };
131
132                self.log_execution(
133                    log_level,
134                    function_name,
135                    result_kind,
136                    &args,
137                    duration,
138                    true,
139                    None,
140                );
141                crate::observability::record_fn_execution(
142                    function_name,
143                    result_kind,
144                    true,
145                    duration.as_secs_f64(),
146                );
147
148                Ok(ExecutionResult {
149                    function_name: function_name.to_string(),
150                    function_kind: result_kind.to_string(),
151                    result: value,
152                    duration,
153                    success: true,
154                    error: None,
155                })
156            }
157            Err(e) => {
158                self.log_execution(
159                    log_level,
160                    function_name,
161                    &kind,
162                    &args,
163                    duration,
164                    false,
165                    Some(&e.to_string()),
166                );
167                crate::observability::record_fn_execution(
168                    function_name,
169                    &kind,
170                    false,
171                    duration.as_secs_f64(),
172                );
173
174                Err(e)
175            }
176        }
177    }
178
179    /// Log function execution at the configured level.
180    #[allow(clippy::too_many_arguments)]
181    fn log_execution(
182        &self,
183        log_level: &str,
184        function_name: &str,
185        kind: &str,
186        input: &Value,
187        duration: Duration,
188        success: bool,
189        error: Option<&str>,
190    ) {
191        macro_rules! log_fn {
192            ($level:ident) => {
193                if success {
194                    $level!(
195                        function = function_name,
196                        kind = kind,
197                        duration_ms = duration.as_millis() as u64,
198                        "Function executed"
199                    );
200                    debug!(
201                        function = function_name,
202                        input = %input,
203                        "Function input"
204                    );
205                } else {
206                    $level!(
207                        function = function_name,
208                        kind = kind,
209                        duration_ms = duration.as_millis() as u64,
210                        error = error,
211                        "Function failed"
212                    );
213                    debug!(
214                        function = function_name,
215                        input = %input,
216                        "Function input"
217                    );
218                }
219            };
220        }
221
222        match log_level {
223            "off" => {}
224            "error" => log_fn!(error),
225            "warn" => log_fn!(warn),
226            "info" => log_fn!(info),
227            "debug" => log_fn!(debug),
228            _ => log_fn!(trace),
229        }
230    }
231
232    /// Mutations default to "info" because writes are worth tracking.
233    /// Queries default to "debug" since they're high-volume.
234    fn get_function_log_level(&self, function_name: &str) -> &'static str {
235        self.registry
236            .get(function_name)
237            .map(|entry| {
238                entry.info().log_level.unwrap_or(match entry.kind() {
239                    forge_core::FunctionKind::Mutation => "info",
240                    forge_core::FunctionKind::Query => "debug",
241                })
242            })
243            .unwrap_or("info")
244    }
245
246    /// Get the timeout for a specific function.
247    fn get_function_timeout(&self, function_name: &str) -> Duration {
248        self.registry
249            .get(function_name)
250            .and_then(|entry| entry.info().timeout)
251            .map(Duration::from_secs)
252            .unwrap_or(self.default_timeout)
253    }
254
255    /// Check if a function exists.
256    pub fn has_function(&self, function_name: &str) -> bool {
257        self.router.has_function(function_name)
258    }
259}
260
261/// Result of executing a function.
262#[derive(Debug, Clone, serde::Serialize)]
263pub struct ExecutionResult {
264    /// Function name that was executed.
265    pub function_name: String,
266    /// Kind of function (query, mutation).
267    pub function_kind: String,
268    /// The result value (or null on error).
269    pub result: Value,
270    /// Execution duration.
271    #[serde(with = "duration_millis")]
272    pub duration: Duration,
273    /// Whether execution succeeded.
274    pub success: bool,
275    /// Error message if failed.
276    pub error: Option<String>,
277}
278
279mod duration_millis {
280    use serde::{Deserialize, Deserializer, Serializer};
281    use std::time::Duration;
282
283    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
284    where
285        S: Serializer,
286    {
287        serializer.serialize_u64(duration.as_millis() as u64)
288    }
289
290    #[allow(dead_code)]
291    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
292    where
293        D: Deserializer<'de>,
294    {
295        let millis = u64::deserialize(deserializer)?;
296        Ok(Duration::from_millis(millis))
297    }
298}
299
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// A successful result must serialize its duration as integer milliseconds
    /// and its `success` flag as a JSON boolean.
    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();

        for fragment in ["\"duration\":42", "\"success\":true"] {
            assert!(encoded.contains(fragment));
        }
    }
}