Skip to main content

forge_runtime/function/
executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13/// Executes functions with timeout and error handling.
14pub struct FunctionExecutor {
15    router: FunctionRouter,
16    registry: Arc<FunctionRegistry>,
17    default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21    /// Create a new function executor.
22    pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23        Self {
24            router: FunctionRouter::new(Arc::clone(&registry), db),
25            registry,
26            default_timeout: Duration::from_secs(30),
27        }
28    }
29
30    /// Create a new function executor with custom timeout.
31    pub fn with_timeout(
32        registry: Arc<FunctionRegistry>,
33        db: Database,
34        default_timeout: Duration,
35    ) -> Self {
36        Self {
37            router: FunctionRouter::new(Arc::clone(&registry), db),
38            registry,
39            default_timeout,
40        }
41    }
42
43    /// Create a new function executor with dispatch capabilities.
44    pub fn with_dispatch(
45        registry: Arc<FunctionRegistry>,
46        db: Database,
47        job_dispatcher: Option<Arc<dyn JobDispatch>>,
48        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49    ) -> Self {
50        Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
51    }
52
53    /// Create a function executor with dispatch and token issuer.
54    pub fn with_dispatch_and_issuer(
55        registry: Arc<FunctionRegistry>,
56        db: Database,
57        job_dispatcher: Option<Arc<dyn JobDispatch>>,
58        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
59        token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
60    ) -> Self {
61        let mut router = FunctionRouter::new(Arc::clone(&registry), db);
62        if let Some(jd) = job_dispatcher {
63            router = router.with_job_dispatcher(jd);
64        }
65        if let Some(wd) = workflow_dispatcher {
66            router = router.with_workflow_dispatcher(wd);
67        }
68        if let Some(issuer) = token_issuer {
69            router = router.with_token_issuer(issuer);
70        }
71        Self {
72            router,
73            registry,
74            default_timeout: Duration::from_secs(30),
75        }
76    }
77
78    /// Set the token TTL config on the underlying router.
79    pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
80        self.router.set_token_ttl(ttl);
81    }
82
83    /// Execute a function call.
84    pub async fn execute(
85        &self,
86        function_name: &str,
87        args: Value,
88        auth: AuthContext,
89        request: RequestMetadata,
90    ) -> Result<ExecutionResult> {
91        let start = std::time::Instant::now();
92        let fn_timeout = self.get_function_timeout(function_name);
93        let log_level = self.get_function_log_level(function_name);
94
95        let kind = self
96            .router
97            .get_function_kind(function_name)
98            .map(|k| k.to_string())
99            .unwrap_or_else(|| "unknown".to_string());
100
101        let span = tracing::info_span!(
102            "fn.execute",
103            function = function_name,
104            fn.kind = %kind,
105        );
106
107        let result = match timeout(
108            fn_timeout,
109            self.router
110                .route(function_name, args.clone(), auth, request)
111                .instrument(span),
112        )
113        .await
114        {
115            Ok(result) => result,
116            Err(_) => {
117                let duration = start.elapsed();
118                self.log_execution(
119                    log_level,
120                    function_name,
121                    "unknown",
122                    &args,
123                    duration,
124                    false,
125                    Some(&format!("Timeout after {:?}", fn_timeout)),
126                );
127                crate::observability::record_fn_execution(
128                    function_name,
129                    &kind,
130                    false,
131                    duration.as_secs_f64(),
132                );
133                return Err(ForgeError::Timeout(format!(
134                    "Function '{}' timed out after {:?}",
135                    function_name, fn_timeout
136                )));
137            }
138        };
139
140        let duration = start.elapsed();
141
142        match result {
143            Ok(route_result) => {
144                let (result_kind, value) = match route_result {
145                    RouteResult::Query(v) => ("query", v),
146                    RouteResult::Mutation(v) => ("mutation", v),
147                    RouteResult::Job(v) => ("job", v),
148                    RouteResult::Workflow(v) => ("workflow", v),
149                };
150
151                self.log_execution(
152                    log_level,
153                    function_name,
154                    result_kind,
155                    &args,
156                    duration,
157                    true,
158                    None,
159                );
160                crate::observability::record_fn_execution(
161                    function_name,
162                    result_kind,
163                    true,
164                    duration.as_secs_f64(),
165                );
166
167                Ok(ExecutionResult {
168                    function_name: function_name.to_string(),
169                    function_kind: result_kind.to_string(),
170                    result: value,
171                    duration,
172                    success: true,
173                    error: None,
174                })
175            }
176            Err(e) => {
177                self.log_execution(
178                    log_level,
179                    function_name,
180                    &kind,
181                    &args,
182                    duration,
183                    false,
184                    Some(&e.to_string()),
185                );
186                crate::observability::record_fn_execution(
187                    function_name,
188                    &kind,
189                    false,
190                    duration.as_secs_f64(),
191                );
192
193                Err(e)
194            }
195        }
196    }
197
198    /// Log function execution at the configured level.
199    #[allow(clippy::too_many_arguments)]
200    fn log_execution(
201        &self,
202        log_level: &str,
203        function_name: &str,
204        kind: &str,
205        input: &Value,
206        duration: Duration,
207        success: bool,
208        error: Option<&str>,
209    ) {
210        // Failures are always logged at error regardless of the function's
211        // configured log level. Successes use the configured level.
212        if !success {
213            error!(
214                function = function_name,
215                kind = kind,
216                duration_ms = duration.as_millis() as u64,
217                error = error,
218                "Function failed"
219            );
220            debug!(
221                function = function_name,
222                input = %input,
223                "Function input"
224            );
225            return;
226        }
227
228        macro_rules! log_fn {
229            ($level:ident) => {{
230                $level!(
231                    function = function_name,
232                    kind = kind,
233                    duration_ms = duration.as_millis() as u64,
234                    "Function executed"
235                );
236                debug!(
237                    function = function_name,
238                    input = %input,
239                    "Function input"
240                );
241            }};
242        }
243
244        match log_level {
245            "off" => {}
246            "error" => log_fn!(error),
247            "warn" => log_fn!(warn),
248            "info" => log_fn!(info),
249            "debug" => log_fn!(debug),
250            _ => log_fn!(trace),
251        }
252    }
253
254    /// Mutations default to "info" because writes are worth tracking.
255    /// Queries default to "debug" since they're high-volume.
256    fn get_function_log_level(&self, function_name: &str) -> &'static str {
257        self.registry
258            .get(function_name)
259            .map(|entry| {
260                entry.info().log_level.unwrap_or(match entry.kind() {
261                    forge_core::FunctionKind::Mutation => "info",
262                    forge_core::FunctionKind::Query => "debug",
263                })
264            })
265            .unwrap_or("info")
266    }
267
268    /// Get the timeout for a specific function.
269    fn get_function_timeout(&self, function_name: &str) -> Duration {
270        self.registry
271            .get(function_name)
272            .and_then(|entry| entry.info().timeout)
273            .map(Duration::from_secs)
274            .unwrap_or(self.default_timeout)
275    }
276
277    /// Check if a function exists.
278    pub fn has_function(&self, function_name: &str) -> bool {
279        self.router.has_function(function_name)
280    }
281}
282
283/// Result of executing a function.
284#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
285pub struct ExecutionResult {
286    /// Function name that was executed.
287    pub function_name: String,
288    /// Kind of function (query, mutation).
289    pub function_kind: String,
290    /// The result value (or null on error).
291    pub result: Value,
292    /// Execution duration.
293    #[serde(with = "duration_millis")]
294    pub duration: Duration,
295    /// Whether execution succeeded.
296    pub success: bool,
297    /// Error message if failed.
298    pub error: Option<String>,
299}
300
301mod duration_millis {
302    use serde::{Deserialize, Deserializer, Serializer};
303    use std::time::Duration;
304
305    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
306    where
307        S: Serializer,
308    {
309        serializer.serialize_u64(duration.as_millis() as u64)
310    }
311
312    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
313    where
314        D: Deserializer<'de>,
315    {
316        let millis = u64::deserialize(deserializer)?;
317        Ok(Duration::from_millis(millis))
318    }
319}
320
#[cfg(test)]
// Test-only code: unwraps and panics are the intended failure mode here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Build a successful `ExecutionResult` fixture with no error message.
    fn succeeded(name: &str, kind: &str, payload: Value, millis: u64) -> ExecutionResult {
        ExecutionResult {
            function_name: name.to_owned(),
            function_kind: kind.to_owned(),
            result: payload,
            duration: Duration::from_millis(millis),
            success: true,
            error: None,
        }
    }

    #[test]
    fn test_execution_result_serialization() {
        let fixture = succeeded("get_user", "query", serde_json::json!({"id": "123"}), 42);

        let encoded = serde_json::to_string(&fixture).unwrap();

        // The duration is written as a bare millisecond count.
        assert!(encoded.contains("\"duration\":42"));
        assert!(encoded.contains("\"success\":true"));
    }

    #[test]
    fn test_execution_result_round_trip() {
        let fixture = succeeded("create_user", "mutation", serde_json::json!({"id": "456"}), 100);

        let encoded = serde_json::to_string(&fixture).unwrap();
        let decoded: ExecutionResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.function_name, "create_user");
        assert_eq!(decoded.duration, Duration::from_millis(100));
        assert!(decoded.success);
    }
}