1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11
12pub struct FunctionExecutor {
14 router: FunctionRouter,
15 registry: Arc<FunctionRegistry>,
16 default_timeout: Duration,
17}
18
19impl FunctionExecutor {
20 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
22 Self {
23 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
24 registry,
25 default_timeout: Duration::from_secs(30),
26 }
27 }
28
29 pub fn with_timeout(
31 registry: Arc<FunctionRegistry>,
32 db_pool: sqlx::PgPool,
33 default_timeout: Duration,
34 ) -> Self {
35 Self {
36 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
37 registry,
38 default_timeout,
39 }
40 }
41
42 pub fn with_dispatch(
44 registry: Arc<FunctionRegistry>,
45 db_pool: sqlx::PgPool,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 ) -> Self {
49 let mut router = FunctionRouter::new(Arc::clone(®istry), db_pool);
50 if let Some(jd) = job_dispatcher {
51 router = router.with_job_dispatcher(jd);
52 }
53 if let Some(wd) = workflow_dispatcher {
54 router = router.with_workflow_dispatcher(wd);
55 }
56 Self {
57 router,
58 registry,
59 default_timeout: Duration::from_secs(30),
60 }
61 }
62
63 pub async fn execute(
65 &self,
66 function_name: &str,
67 args: Value,
68 auth: AuthContext,
69 request: RequestMetadata,
70 ) -> Result<ExecutionResult> {
71 let start = std::time::Instant::now();
72
73 let fn_timeout = self.get_function_timeout(function_name);
75
76 let log_level = self.get_function_log_level(function_name);
78
79 let result = match timeout(
81 fn_timeout,
82 self.router
83 .route(function_name, args.clone(), auth, request),
84 )
85 .await
86 {
87 Ok(result) => result,
88 Err(_) => {
89 let duration = start.elapsed();
90 self.log_execution(
91 log_level,
92 function_name,
93 "unknown",
94 &args,
95 duration,
96 false,
97 Some(&format!("Timeout after {:?}", fn_timeout)),
98 );
99 return Err(ForgeError::Timeout(format!(
100 "Function '{}' timed out after {:?}",
101 function_name, fn_timeout
102 )));
103 }
104 };
105
106 let duration = start.elapsed();
107
108 match result {
109 Ok(route_result) => {
110 let (kind, value) = match route_result {
111 RouteResult::Query(v) => ("query", v),
112 RouteResult::Mutation(v) => ("mutation", v),
113 RouteResult::Action(v) => ("action", v),
114 };
115
116 self.log_execution(log_level, function_name, kind, &args, duration, true, None);
117
118 Ok(ExecutionResult {
119 function_name: function_name.to_string(),
120 function_kind: kind.to_string(),
121 result: value,
122 duration,
123 success: true,
124 error: None,
125 })
126 }
127 Err(e) => {
128 let kind = self
129 .router
130 .get_function_kind(function_name)
131 .map(|k| k.to_string())
132 .unwrap_or_else(|| "unknown".to_string());
133
134 self.log_execution(
135 log_level,
136 function_name,
137 &kind,
138 &args,
139 duration,
140 false,
141 Some(&e.to_string()),
142 );
143
144 Ok(ExecutionResult {
145 function_name: function_name.to_string(),
146 function_kind: kind,
147 result: Value::Null,
148 duration,
149 success: false,
150 error: Some(e.to_string()),
151 })
152 }
153 }
154 }
155
156 #[allow(clippy::too_many_arguments)]
158 fn log_execution(
159 &self,
160 log_level: &str,
161 function_name: &str,
162 kind: &str,
163 input: &Value,
164 duration: Duration,
165 success: bool,
166 error: Option<&str>,
167 ) {
168 let duration_ms = duration.as_millis();
169 let input_str = input.to_string();
170
171 match log_level {
172 "off" => {}
173 "error" => {
174 if success {
175 error!(
176 function = function_name,
177 kind = kind,
178 input = input_str,
179 duration_ms = duration_ms,
180 success = success,
181 "Function executed"
182 );
183 } else {
184 error!(
185 function = function_name,
186 kind = kind,
187 input = input_str,
188 duration_ms = duration_ms,
189 success = success,
190 error = error,
191 "Function failed"
192 );
193 }
194 }
195 "warn" => {
196 if success {
197 warn!(
198 function = function_name,
199 kind = kind,
200 input = input_str,
201 duration_ms = duration_ms,
202 success = success,
203 "Function executed"
204 );
205 } else {
206 warn!(
207 function = function_name,
208 kind = kind,
209 input = input_str,
210 duration_ms = duration_ms,
211 success = success,
212 error = error,
213 "Function failed"
214 );
215 }
216 }
217 "info" => {
218 if success {
219 info!(
220 function = function_name,
221 kind = kind,
222 input = input_str,
223 duration_ms = duration_ms,
224 success = success,
225 "Function executed"
226 );
227 } else {
228 info!(
229 function = function_name,
230 kind = kind,
231 input = input_str,
232 duration_ms = duration_ms,
233 success = success,
234 error = error,
235 "Function failed"
236 );
237 }
238 }
239 "debug" => {
240 if success {
241 debug!(
242 function = function_name,
243 kind = kind,
244 input = input_str,
245 duration_ms = duration_ms,
246 success = success,
247 "Function executed"
248 );
249 } else {
250 debug!(
251 function = function_name,
252 kind = kind,
253 input = input_str,
254 duration_ms = duration_ms,
255 success = success,
256 error = error,
257 "Function failed"
258 );
259 }
260 }
261 _ => {
263 if success {
264 trace!(
265 function = function_name,
266 kind = kind,
267 input = input_str,
268 duration_ms = duration_ms,
269 success = success,
270 "Function executed"
271 );
272 } else {
273 trace!(
274 function = function_name,
275 kind = kind,
276 input = input_str,
277 duration_ms = duration_ms,
278 success = success,
279 error = error,
280 "Function failed"
281 );
282 }
283 }
284 }
285 }
286
287 fn get_function_log_level(&self, function_name: &str) -> &'static str {
289 self.registry
290 .get(function_name)
291 .and_then(|entry| entry.info().log_level)
292 .unwrap_or("trace")
293 }
294
295 fn get_function_timeout(&self, function_name: &str) -> Duration {
297 self.registry
298 .get(function_name)
299 .and_then(|entry| entry.info().timeout)
300 .map(Duration::from_secs)
301 .unwrap_or(self.default_timeout)
302 }
303
304 pub fn has_function(&self, function_name: &str) -> bool {
306 self.router.has_function(function_name)
307 }
308}
309
310#[derive(Debug, Clone, serde::Serialize)]
312pub struct ExecutionResult {
313 pub function_name: String,
315 pub function_kind: String,
317 pub result: Value,
319 #[serde(with = "duration_millis")]
321 pub duration: Duration,
322 pub success: bool,
324 pub error: Option<String>,
326}
327
328mod duration_millis {
329 use serde::{Deserialize, Deserializer, Serializer};
330 use std::time::Duration;
331
332 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
333 where
334 S: Serializer,
335 {
336 serializer.serialize_u64(duration.as_millis() as u64)
337 }
338
339 #[allow(dead_code)]
340 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
341 where
342 D: Deserializer<'de>,
343 {
344 let millis = u64::deserialize(deserializer)?;
345 Ok(Duration::from_millis(millis))
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"duration\":42"));
        assert!(json.contains("\"success\":true"));
    }

    /// A failed result serializes every field, including `null` result and the
    /// error message.
    #[test]
    fn test_failed_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "delete_user".to_string(),
            function_kind: "unknown".to_string(),
            result: Value::Null,
            duration: Duration::from_millis(7),
            success: false,
            error: Some("boom".to_string()),
        };

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "function_name": "delete_user",
                "function_kind": "unknown",
                "result": null,
                "duration": 7,
                "success": false,
                "error": "boom"
            })
        );
    }

    /// `duration` is emitted as whole milliseconds; the sub-millisecond part is
    /// dropped.
    #[test]
    fn test_duration_truncates_to_whole_millis() {
        let result = ExecutionResult {
            function_name: "f".to_string(),
            function_kind: "action".to_string(),
            result: Value::Null,
            duration: Duration::from_micros(1_999),
            success: true,
            error: None,
        };

        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["duration"], 1);
    }

    #[test]
    fn test_duration_millis_deserialize() {
        let mut de = serde_json::Deserializer::from_str("1500");
        let duration = duration_millis::deserialize(&mut de).unwrap();
        assert_eq!(duration, Duration::from_millis(1500));
    }
}