1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11
12pub struct FunctionExecutor {
14 router: FunctionRouter,
15 registry: Arc<FunctionRegistry>,
16 default_timeout: Duration,
17}
18
19impl FunctionExecutor {
20 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
22 Self {
23 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
24 registry,
25 default_timeout: Duration::from_secs(30),
26 }
27 }
28
29 pub fn with_timeout(
31 registry: Arc<FunctionRegistry>,
32 db_pool: sqlx::PgPool,
33 default_timeout: Duration,
34 ) -> Self {
35 Self {
36 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
37 registry,
38 default_timeout,
39 }
40 }
41
42 pub fn with_dispatch(
44 registry: Arc<FunctionRegistry>,
45 db_pool: sqlx::PgPool,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 ) -> Self {
49 let mut router = FunctionRouter::new(Arc::clone(®istry), db_pool);
50 if let Some(jd) = job_dispatcher {
51 router = router.with_job_dispatcher(jd);
52 }
53 if let Some(wd) = workflow_dispatcher {
54 router = router.with_workflow_dispatcher(wd);
55 }
56 Self {
57 router,
58 registry,
59 default_timeout: Duration::from_secs(30),
60 }
61 }
62
63 pub async fn execute(
65 &self,
66 function_name: &str,
67 args: Value,
68 auth: AuthContext,
69 request: RequestMetadata,
70 ) -> Result<ExecutionResult> {
71 let start = std::time::Instant::now();
72
73 let fn_timeout = self.get_function_timeout(function_name);
75
76 let log_level = self.get_function_log_level(function_name);
78
79 let result = match timeout(
81 fn_timeout,
82 self.router
83 .route(function_name, args.clone(), auth, request),
84 )
85 .await
86 {
87 Ok(result) => result,
88 Err(_) => {
89 let duration = start.elapsed();
90 self.log_execution(
91 log_level,
92 function_name,
93 "unknown",
94 &args,
95 duration,
96 false,
97 Some(&format!("Timeout after {:?}", fn_timeout)),
98 );
99 return Err(ForgeError::Timeout(format!(
100 "Function '{}' timed out after {:?}",
101 function_name, fn_timeout
102 )));
103 }
104 };
105
106 let duration = start.elapsed();
107
108 match result {
109 Ok(route_result) => {
110 let (kind, value) = match route_result {
111 RouteResult::Query(v) => ("query", v),
112 RouteResult::Mutation(v) => ("mutation", v),
113 };
114
115 self.log_execution(log_level, function_name, kind, &args, duration, true, None);
116
117 Ok(ExecutionResult {
118 function_name: function_name.to_string(),
119 function_kind: kind.to_string(),
120 result: value,
121 duration,
122 success: true,
123 error: None,
124 })
125 }
126 Err(e) => {
127 let kind = self
128 .router
129 .get_function_kind(function_name)
130 .map(|k| k.to_string())
131 .unwrap_or_else(|| "unknown".to_string());
132
133 self.log_execution(
134 log_level,
135 function_name,
136 &kind,
137 &args,
138 duration,
139 false,
140 Some(&e.to_string()),
141 );
142
143 Ok(ExecutionResult {
144 function_name: function_name.to_string(),
145 function_kind: kind,
146 result: Value::Null,
147 duration,
148 success: false,
149 error: Some(e.to_string()),
150 })
151 }
152 }
153 }
154
155 #[allow(clippy::too_many_arguments)]
157 fn log_execution(
158 &self,
159 log_level: &str,
160 function_name: &str,
161 kind: &str,
162 input: &Value,
163 duration: Duration,
164 success: bool,
165 error: Option<&str>,
166 ) {
167 let duration_ms = duration.as_millis();
168 let input_str = input.to_string();
169
170 match log_level {
171 "off" => {}
172 "error" => {
173 if success {
174 error!(
175 function = function_name,
176 kind = kind,
177 input = input_str,
178 duration_ms = duration_ms,
179 success = success,
180 "Function executed"
181 );
182 } else {
183 error!(
184 function = function_name,
185 kind = kind,
186 input = input_str,
187 duration_ms = duration_ms,
188 success = success,
189 error = error,
190 "Function failed"
191 );
192 }
193 }
194 "warn" => {
195 if success {
196 warn!(
197 function = function_name,
198 kind = kind,
199 input = input_str,
200 duration_ms = duration_ms,
201 success = success,
202 "Function executed"
203 );
204 } else {
205 warn!(
206 function = function_name,
207 kind = kind,
208 input = input_str,
209 duration_ms = duration_ms,
210 success = success,
211 error = error,
212 "Function failed"
213 );
214 }
215 }
216 "info" => {
217 if success {
218 info!(
219 function = function_name,
220 kind = kind,
221 input = input_str,
222 duration_ms = duration_ms,
223 success = success,
224 "Function executed"
225 );
226 } else {
227 info!(
228 function = function_name,
229 kind = kind,
230 input = input_str,
231 duration_ms = duration_ms,
232 success = success,
233 error = error,
234 "Function failed"
235 );
236 }
237 }
238 "debug" => {
239 if success {
240 debug!(
241 function = function_name,
242 kind = kind,
243 input = input_str,
244 duration_ms = duration_ms,
245 success = success,
246 "Function executed"
247 );
248 } else {
249 debug!(
250 function = function_name,
251 kind = kind,
252 input = input_str,
253 duration_ms = duration_ms,
254 success = success,
255 error = error,
256 "Function failed"
257 );
258 }
259 }
260 _ => {
262 if success {
263 trace!(
264 function = function_name,
265 kind = kind,
266 input = input_str,
267 duration_ms = duration_ms,
268 success = success,
269 "Function executed"
270 );
271 } else {
272 trace!(
273 function = function_name,
274 kind = kind,
275 input = input_str,
276 duration_ms = duration_ms,
277 success = success,
278 error = error,
279 "Function failed"
280 );
281 }
282 }
283 }
284 }
285
286 fn get_function_log_level(&self, function_name: &str) -> &'static str {
288 self.registry
289 .get(function_name)
290 .and_then(|entry| entry.info().log_level)
291 .unwrap_or("trace")
292 }
293
294 fn get_function_timeout(&self, function_name: &str) -> Duration {
296 self.registry
297 .get(function_name)
298 .and_then(|entry| entry.info().timeout)
299 .map(Duration::from_secs)
300 .unwrap_or(self.default_timeout)
301 }
302
303 pub fn has_function(&self, function_name: &str) -> bool {
305 self.router.has_function(function_name)
306 }
307}
308
309#[derive(Debug, Clone, serde::Serialize)]
311pub struct ExecutionResult {
312 pub function_name: String,
314 pub function_kind: String,
316 pub result: Value,
318 #[serde(with = "duration_millis")]
320 pub duration: Duration,
321 pub success: bool,
323 pub error: Option<String>,
325}
326
327mod duration_millis {
328 use serde::{Deserialize, Deserializer, Serializer};
329 use std::time::Duration;
330
331 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
332 where
333 S: Serializer,
334 {
335 serializer.serialize_u64(duration.as_millis() as u64)
336 }
337
338 #[allow(dead_code)]
339 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
340 where
341 D: Deserializer<'de>,
342 {
343 let millis = u64::deserialize(deserializer)?;
344 Ok(Duration::from_millis(millis))
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized result carries the duration as a bare millisecond number and
    /// the success flag as a JSON boolean.
    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();

        assert!(encoded.contains("\"duration\":42"));
        assert!(encoded.contains("\"success\":true"));
    }
}