1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11
12pub struct FunctionExecutor {
14 router: FunctionRouter,
15 registry: Arc<FunctionRegistry>,
16 default_timeout: Duration,
17}
18
19impl FunctionExecutor {
20 pub fn new(registry: Arc<FunctionRegistry>, db_pool: sqlx::PgPool) -> Self {
22 Self {
23 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
24 registry,
25 default_timeout: Duration::from_secs(30),
26 }
27 }
28
29 pub fn with_timeout(
31 registry: Arc<FunctionRegistry>,
32 db_pool: sqlx::PgPool,
33 default_timeout: Duration,
34 ) -> Self {
35 Self {
36 router: FunctionRouter::new(Arc::clone(®istry), db_pool),
37 registry,
38 default_timeout,
39 }
40 }
41
42 pub fn with_dispatch(
44 registry: Arc<FunctionRegistry>,
45 db_pool: sqlx::PgPool,
46 job_dispatcher: Option<Arc<dyn JobDispatch>>,
47 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
48 ) -> Self {
49 let mut router = FunctionRouter::new(Arc::clone(®istry), db_pool);
50 if let Some(jd) = job_dispatcher {
51 router = router.with_job_dispatcher(jd);
52 }
53 if let Some(wd) = workflow_dispatcher {
54 router = router.with_workflow_dispatcher(wd);
55 }
56 Self {
57 router,
58 registry,
59 default_timeout: Duration::from_secs(30),
60 }
61 }
62
63 pub async fn execute(
65 &self,
66 function_name: &str,
67 args: Value,
68 auth: AuthContext,
69 request: RequestMetadata,
70 ) -> Result<ExecutionResult> {
71 let start = std::time::Instant::now();
72
73 let fn_timeout = self.get_function_timeout(function_name);
75
76 let log_level = self.get_function_log_level(function_name);
78
79 let result = match timeout(
81 fn_timeout,
82 self.router
83 .route(function_name, args.clone(), auth, request),
84 )
85 .await
86 {
87 Ok(result) => result,
88 Err(_) => {
89 let duration = start.elapsed();
90 self.log_execution(
91 log_level,
92 function_name,
93 "unknown",
94 &args,
95 duration,
96 false,
97 Some(&format!("Timeout after {:?}", fn_timeout)),
98 );
99 return Err(ForgeError::Timeout(format!(
100 "Function '{}' timed out after {:?}",
101 function_name, fn_timeout
102 )));
103 }
104 };
105
106 let duration = start.elapsed();
107
108 match result {
109 Ok(route_result) => {
110 let (kind, value) = match route_result {
111 RouteResult::Query(v) => ("query", v),
112 RouteResult::Mutation(v) => ("mutation", v),
113 RouteResult::Job(v) => ("job", v),
114 RouteResult::Workflow(v) => ("workflow", v),
115 };
116
117 self.log_execution(log_level, function_name, kind, &args, duration, true, None);
118
119 Ok(ExecutionResult {
120 function_name: function_name.to_string(),
121 function_kind: kind.to_string(),
122 result: value,
123 duration,
124 success: true,
125 error: None,
126 })
127 }
128 Err(e) => {
129 let kind = self
130 .router
131 .get_function_kind(function_name)
132 .map(|k| k.to_string())
133 .unwrap_or_else(|| "unknown".to_string());
134
135 self.log_execution(
136 log_level,
137 function_name,
138 &kind,
139 &args,
140 duration,
141 false,
142 Some(&e.to_string()),
143 );
144
145 Err(e)
146 }
147 }
148 }
149
150 #[allow(clippy::too_many_arguments)]
152 fn log_execution(
153 &self,
154 log_level: &str,
155 function_name: &str,
156 kind: &str,
157 input: &Value,
158 duration: Duration,
159 success: bool,
160 error: Option<&str>,
161 ) {
162 macro_rules! log_fn {
163 ($level:ident) => {
164 if success {
165 $level!(
166 function = function_name,
167 kind = kind,
168 input = %input,
169 duration_ms = duration.as_millis() as u64,
170 success = success,
171 "Function executed"
172 );
173 } else {
174 $level!(
175 function = function_name,
176 kind = kind,
177 input = %input,
178 duration_ms = duration.as_millis() as u64,
179 success = success,
180 error = error,
181 "Function failed"
182 );
183 }
184 };
185 }
186
187 match log_level {
188 "off" => {}
189 "error" => log_fn!(error),
190 "warn" => log_fn!(warn),
191 "info" => log_fn!(info),
192 "debug" => log_fn!(debug),
193 _ => log_fn!(trace),
194 }
195 }
196
197 fn get_function_log_level(&self, function_name: &str) -> &'static str {
199 self.registry
200 .get(function_name)
201 .and_then(|entry| entry.info().log_level)
202 .unwrap_or("trace")
203 }
204
205 fn get_function_timeout(&self, function_name: &str) -> Duration {
207 self.registry
208 .get(function_name)
209 .and_then(|entry| entry.info().timeout)
210 .map(Duration::from_secs)
211 .unwrap_or(self.default_timeout)
212 }
213
214 pub fn has_function(&self, function_name: &str) -> bool {
216 self.router.has_function(function_name)
217 }
218}
219
220#[derive(Debug, Clone, serde::Serialize)]
222pub struct ExecutionResult {
223 pub function_name: String,
225 pub function_kind: String,
227 pub result: Value,
229 #[serde(with = "duration_millis")]
231 pub duration: Duration,
232 pub success: bool,
234 pub error: Option<String>,
236}
237
238mod duration_millis {
239 use serde::{Deserialize, Deserializer, Serializer};
240 use std::time::Duration;
241
242 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
243 where
244 S: Serializer,
245 {
246 serializer.serialize_u64(duration.as_millis() as u64)
247 }
248
249 #[allow(dead_code)]
250 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
251 where
252 D: Deserializer<'de>,
253 {
254 let millis = u64::deserialize(deserializer)?;
255 Ok(Duration::from_millis(millis))
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// An `ExecutionResult` must serialize its duration as integer
    /// milliseconds and its success flag as a JSON boolean.
    #[test]
    fn test_execution_result_serialization() {
        let elapsed = Duration::from_millis(42);
        let payload = serde_json::json!({"id": "123"});

        let execution = ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: payload,
            duration: elapsed,
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&execution).unwrap();

        for fragment in ["\"duration\":42", "\"success\":true"] {
            assert!(encoded.contains(fragment));
        }
    }
}