1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7
8use super::registry::FunctionRegistry;
9use super::router::{FunctionRouter, RouteResult};
10
11pub struct FunctionExecutor {
13 router: FunctionRouter,
14 registry: Arc<FunctionRegistry>,
15 default_timeout: Duration,
16}
17
18impl FunctionExecutor {
19 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
21 Self {
22 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
23 registry,
24 default_timeout: Duration::from_secs(30),
25 }
26 }
27
28 pub fn with_timeout(
30 registry: Arc<FunctionRegistry>,
31 db_pool: sqlx::PgPool,
32 default_timeout: Duration,
33 ) -> Self {
34 Self {
35 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
36 registry,
37 default_timeout,
38 }
39 }
40
41 pub fn with_dispatch(
43 registry: Arc<FunctionRegistry>,
44 db_pool: sqlx::PgPool,
45 job_dispatcher: Option<Arc<dyn JobDispatch>>,
46 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
47 ) -> Self {
48 let mut router = FunctionRouter::new(Arc::clone(®istry), db_pool);
49 if let Some(jd) = job_dispatcher {
50 router = router.with_job_dispatcher(jd);
51 }
52 if let Some(wd) = workflow_dispatcher {
53 router = router.with_workflow_dispatcher(wd);
54 }
55 Self {
56 router,
57 registry,
58 default_timeout: Duration::from_secs(30),
59 }
60 }
61
62 pub async fn execute(
64 &self,
65 function_name: &str,
66 args: Value,
67 auth: AuthContext,
68 request: RequestMetadata,
69 ) -> Result<ExecutionResult> {
70 let start = std::time::Instant::now();
71
72 let fn_timeout = self.get_function_timeout(function_name);
74
75 let result = match timeout(
77 fn_timeout,
78 self.router.route(function_name, args, auth, request),
79 )
80 .await
81 {
82 Ok(result) => result,
83 Err(_) => {
84 return Err(ForgeError::Timeout(format!(
85 "Function '{}' timed out after {:?}",
86 function_name, fn_timeout
87 )));
88 }
89 };
90
91 let duration = start.elapsed();
92
93 match result {
94 Ok(route_result) => {
95 let (kind, value) = match route_result {
96 RouteResult::Query(v) => ("query", v),
97 RouteResult::Mutation(v) => ("mutation", v),
98 RouteResult::Action(v) => ("action", v),
99 };
100
101 Ok(ExecutionResult {
102 function_name: function_name.to_string(),
103 function_kind: kind.to_string(),
104 result: value,
105 duration,
106 success: true,
107 error: None,
108 })
109 }
110 Err(e) => Ok(ExecutionResult {
111 function_name: function_name.to_string(),
112 function_kind: self
113 .router
114 .get_function_kind(function_name)
115 .map(|k| k.to_string())
116 .unwrap_or_else(|| "unknown".to_string()),
117 result: Value::Null,
118 duration,
119 success: false,
120 error: Some(e.to_string()),
121 }),
122 }
123 }
124
125 fn get_function_timeout(&self, function_name: &str) -> Duration {
127 self.registry
128 .get(function_name)
129 .and_then(|entry| entry.info().timeout)
130 .map(Duration::from_secs)
131 .unwrap_or(self.default_timeout)
132 }
133
134 pub fn has_function(&self, function_name: &str) -> bool {
136 self.router.has_function(function_name)
137 }
138}
139
140#[derive(Debug, Clone, serde::Serialize)]
142pub struct ExecutionResult {
143 pub function_name: String,
145 pub function_kind: String,
147 pub result: Value,
149 #[serde(with = "duration_millis")]
151 pub duration: Duration,
152 pub success: bool,
154 pub error: Option<String>,
156}
157
158mod duration_millis {
159 use serde::{Deserialize, Deserializer, Serializer};
160 use std::time::Duration;
161
162 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
163 where
164 S: Serializer,
165 {
166 serializer.serialize_u64(duration.as_millis() as u64)
167 }
168
169 #[allow(dead_code)]
170 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
171 where
172 D: Deserializer<'de>,
173 {
174 let millis = u64::deserialize(deserializer)?;
175 Ok(Duration::from_millis(millis))
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes a successful result and checks every field of the resulting JSON.
    ///
    /// The JSON is compared as a parsed value rather than by substring, so a
    /// duration of `420` can no longer satisfy a check meant for `42`.
    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let json = serde_json::to_value(&result).expect("ExecutionResult serializes to JSON");

        assert_eq!(json["function_name"], "get_user");
        assert_eq!(json["function_kind"], "query");
        assert_eq!(json["result"], serde_json::json!({"id": "123"}));
        // `duration` goes through `duration_millis`, so it is a plain integer.
        assert_eq!(json["duration"], 42);
        assert_eq!(json["success"], true);
        assert!(json["error"].is_null());
    }
}