// forge_runtime/function/executor.rs

1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7
8use super::registry::FunctionRegistry;
9use super::router::{FunctionRouter, RouteResult};
10
11/// Executes functions with timeout and error handling.
12pub struct FunctionExecutor {
13    router: FunctionRouter,
14    registry: Arc<FunctionRegistry>,
15    default_timeout: Duration,
16}
17
18impl FunctionExecutor {
19    /// Create a new function executor.
20    pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
21        Self {
22            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
23            registry,
24            default_timeout: Duration::from_secs(30),
25        }
26    }
27
28    /// Create a new function executor with custom timeout.
29    pub fn with_timeout(
30        registry: Arc<FunctionRegistry>,
31        db_pool: sqlx::PgPool,
32        default_timeout: Duration,
33    ) -> Self {
34        Self {
35            router: FunctionRouter::new(Arc::clone(&registry), db_pool),
36            registry,
37            default_timeout,
38        }
39    }
40
41    /// Create a new function executor with dispatch capabilities.
42    pub fn with_dispatch(
43        registry: Arc<FunctionRegistry>,
44        db_pool: sqlx::PgPool,
45        job_dispatcher: Option<Arc<dyn JobDispatch>>,
46        workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
47    ) -> Self {
48        let mut router = FunctionRouter::new(Arc::clone(&registry), db_pool);
49        if let Some(jd) = job_dispatcher {
50            router = router.with_job_dispatcher(jd);
51        }
52        if let Some(wd) = workflow_dispatcher {
53            router = router.with_workflow_dispatcher(wd);
54        }
55        Self {
56            router,
57            registry,
58            default_timeout: Duration::from_secs(30),
59        }
60    }
61
62    /// Execute a function call.
63    pub async fn execute(
64        &self,
65        function_name: &str,
66        args: Value,
67        auth: AuthContext,
68        request: RequestMetadata,
69    ) -> Result<ExecutionResult> {
70        let start = std::time::Instant::now();
71
72        // Get function-specific timeout or use default
73        let fn_timeout = self.get_function_timeout(function_name);
74
75        // Execute with timeout
76        let result = match timeout(
77            fn_timeout,
78            self.router.route(function_name, args, auth, request),
79        )
80        .await
81        {
82            Ok(result) => result,
83            Err(_) => {
84                return Err(ForgeError::Timeout(format!(
85                    "Function '{}' timed out after {:?}",
86                    function_name, fn_timeout
87                )));
88            }
89        };
90
91        let duration = start.elapsed();
92
93        match result {
94            Ok(route_result) => {
95                let (kind, value) = match route_result {
96                    RouteResult::Query(v) => ("query", v),
97                    RouteResult::Mutation(v) => ("mutation", v),
98                    RouteResult::Action(v) => ("action", v),
99                };
100
101                Ok(ExecutionResult {
102                    function_name: function_name.to_string(),
103                    function_kind: kind.to_string(),
104                    result: value,
105                    duration,
106                    success: true,
107                    error: None,
108                })
109            }
110            Err(e) => Ok(ExecutionResult {
111                function_name: function_name.to_string(),
112                function_kind: self
113                    .router
114                    .get_function_kind(function_name)
115                    .map(|k| k.to_string())
116                    .unwrap_or_else(|| "unknown".to_string()),
117                result: Value::Null,
118                duration,
119                success: false,
120                error: Some(e.to_string()),
121            }),
122        }
123    }
124
125    /// Get the timeout for a specific function.
126    fn get_function_timeout(&self, function_name: &str) -> Duration {
127        self.registry
128            .get(function_name)
129            .and_then(|entry| entry.info().timeout)
130            .map(Duration::from_secs)
131            .unwrap_or(self.default_timeout)
132    }
133
134    /// Check if a function exists.
135    pub fn has_function(&self, function_name: &str) -> bool {
136        self.router.has_function(function_name)
137    }
138}
139
140/// Result of executing a function.
141#[derive(Debug, Clone, serde::Serialize)]
142pub struct ExecutionResult {
143    /// Function name that was executed.
144    pub function_name: String,
145    /// Kind of function (query, mutation, action).
146    pub function_kind: String,
147    /// The result value (or null on error).
148    pub result: Value,
149    /// Execution duration.
150    #[serde(with = "duration_millis")]
151    pub duration: Duration,
152    /// Whether execution succeeded.
153    pub success: bool,
154    /// Error message if failed.
155    pub error: Option<String>,
156}
157
158mod duration_millis {
159    use serde::{Deserialize, Deserializer, Serializer};
160    use std::time::Duration;
161
162    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
163    where
164        S: Serializer,
165    {
166        serializer.serialize_u64(duration.as_millis() as u64)
167    }
168
169    #[allow(dead_code)]
170    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
171    where
172        D: Deserializer<'de>,
173    {
174        let millis = u64::deserialize(deserializer)?;
175        Ok(Duration::from_millis(millis))
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: serde_json::json!({ "id": "123" }),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();

        // The duration is emitted as integer milliseconds by `duration_millis`.
        for needle in ["\"duration\":42", "\"success\":true"] {
            assert!(
                encoded.contains(needle),
                "expected {} in {}",
                needle,
                encoded
            );
        }
    }
}