1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
/// Executes registered functions through a [`FunctionRouter`].
///
/// Each call gets a per-function timeout, and its outcome is logged and
/// recorded as a metric.
pub struct FunctionExecutor {
    /// Routes a call by function name; results come back as a `RouteResult`.
    router: FunctionRouter,
    /// Registry consulted for per-function timeout and log-level settings.
    registry: Arc<FunctionRegistry>,
    /// Timeout used when a function's registry entry does not set its own.
    default_timeout: Duration,
}
19
20impl FunctionExecutor {
21 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23 Self {
24 router: FunctionRouter::new(Arc::clone(®istry), db),
25 registry,
26 default_timeout: Duration::from_secs(30),
27 }
28 }
29
30 pub fn with_timeout(
32 registry: Arc<FunctionRegistry>,
33 db: Database,
34 default_timeout: Duration,
35 ) -> Self {
36 Self {
37 router: FunctionRouter::new(Arc::clone(®istry), db),
38 registry,
39 default_timeout,
40 }
41 }
42
43 pub fn with_dispatch(
45 registry: Arc<FunctionRegistry>,
46 db: Database,
47 job_dispatcher: Option<Arc<dyn JobDispatch>>,
48 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49 ) -> Self {
50 let mut router = FunctionRouter::new(Arc::clone(®istry), db);
51 if let Some(jd) = job_dispatcher {
52 router = router.with_job_dispatcher(jd);
53 }
54 if let Some(wd) = workflow_dispatcher {
55 router = router.with_workflow_dispatcher(wd);
56 }
57 Self {
58 router,
59 registry,
60 default_timeout: Duration::from_secs(30),
61 }
62 }
63
64 pub async fn execute(
66 &self,
67 function_name: &str,
68 args: Value,
69 auth: AuthContext,
70 request: RequestMetadata,
71 ) -> Result<ExecutionResult> {
72 let start = std::time::Instant::now();
73 let fn_timeout = self.get_function_timeout(function_name);
74 let log_level = self.get_function_log_level(function_name);
75
76 let kind = self
77 .router
78 .get_function_kind(function_name)
79 .map(|k| k.to_string())
80 .unwrap_or_else(|| "unknown".to_string());
81
82 let span = tracing::info_span!(
83 "fn.execute",
84 function = function_name,
85 kind = %kind,
86 );
87
88 let result = match timeout(
89 fn_timeout,
90 self.router
91 .route(function_name, args.clone(), auth, request)
92 .instrument(span),
93 )
94 .await
95 {
96 Ok(result) => result,
97 Err(_) => {
98 let duration = start.elapsed();
99 self.log_execution(
100 log_level,
101 function_name,
102 "unknown",
103 &args,
104 duration,
105 false,
106 Some(&format!("Timeout after {:?}", fn_timeout)),
107 );
108 crate::observability::record_fn_execution(
109 function_name,
110 &kind,
111 false,
112 duration.as_secs_f64(),
113 );
114 return Err(ForgeError::Timeout(format!(
115 "Function '{}' timed out after {:?}",
116 function_name, fn_timeout
117 )));
118 }
119 };
120
121 let duration = start.elapsed();
122
123 match result {
124 Ok(route_result) => {
125 let (result_kind, value) = match route_result {
126 RouteResult::Query(v) => ("query", v),
127 RouteResult::Mutation(v) => ("mutation", v),
128 RouteResult::Job(v) => ("job", v),
129 RouteResult::Workflow(v) => ("workflow", v),
130 };
131
132 self.log_execution(
133 log_level,
134 function_name,
135 result_kind,
136 &args,
137 duration,
138 true,
139 None,
140 );
141 crate::observability::record_fn_execution(
142 function_name,
143 result_kind,
144 true,
145 duration.as_secs_f64(),
146 );
147
148 Ok(ExecutionResult {
149 function_name: function_name.to_string(),
150 function_kind: result_kind.to_string(),
151 result: value,
152 duration,
153 success: true,
154 error: None,
155 })
156 }
157 Err(e) => {
158 self.log_execution(
159 log_level,
160 function_name,
161 &kind,
162 &args,
163 duration,
164 false,
165 Some(&e.to_string()),
166 );
167 crate::observability::record_fn_execution(
168 function_name,
169 &kind,
170 false,
171 duration.as_secs_f64(),
172 );
173
174 Err(e)
175 }
176 }
177 }
178
179 #[allow(clippy::too_many_arguments)]
181 fn log_execution(
182 &self,
183 log_level: &str,
184 function_name: &str,
185 kind: &str,
186 input: &Value,
187 duration: Duration,
188 success: bool,
189 error: Option<&str>,
190 ) {
191 macro_rules! log_fn {
192 ($level:ident) => {
193 if success {
194 $level!(
195 function = function_name,
196 kind = kind,
197 duration_ms = duration.as_millis() as u64,
198 "Function executed"
199 );
200 debug!(
201 function = function_name,
202 input = %input,
203 "Function input"
204 );
205 } else {
206 $level!(
207 function = function_name,
208 kind = kind,
209 duration_ms = duration.as_millis() as u64,
210 error = error,
211 "Function failed"
212 );
213 debug!(
214 function = function_name,
215 input = %input,
216 "Function input"
217 );
218 }
219 };
220 }
221
222 match log_level {
223 "off" => {}
224 "error" => log_fn!(error),
225 "warn" => log_fn!(warn),
226 "info" => log_fn!(info),
227 "debug" => log_fn!(debug),
228 _ => log_fn!(trace),
229 }
230 }
231
232 fn get_function_log_level(&self, function_name: &str) -> &'static str {
235 self.registry
236 .get(function_name)
237 .map(|entry| {
238 entry.info().log_level.unwrap_or(match entry.kind() {
239 forge_core::FunctionKind::Mutation => "info",
240 forge_core::FunctionKind::Query => "debug",
241 })
242 })
243 .unwrap_or("info")
244 }
245
246 fn get_function_timeout(&self, function_name: &str) -> Duration {
248 self.registry
249 .get(function_name)
250 .and_then(|entry| entry.info().timeout)
251 .map(Duration::from_secs)
252 .unwrap_or(self.default_timeout)
253 }
254
255 pub fn has_function(&self, function_name: &str) -> bool {
257 self.router.has_function(function_name)
258 }
259}
260
/// Serializable outcome of a function execution.
///
/// Within this module it is built by `FunctionExecutor::execute` for successful
/// calls only.
#[derive(Debug, Clone, serde::Serialize)]
pub struct ExecutionResult {
    /// Name of the function that was invoked.
    pub function_name: String,
    /// Kind reported by the router (`execute` sets `"query"`, `"mutation"`,
    /// `"job"` or `"workflow"`).
    pub function_kind: String,
    /// Value returned by the function.
    pub result: Value,
    /// Elapsed execution time, serialized as an integer number of milliseconds.
    #[serde(with = "duration_millis")]
    pub duration: Duration,
    /// Whether the call succeeded (`execute` always sets `true`).
    pub success: bool,
    /// Failure message, if any (`execute` always sets `None`).
    pub error: Option<String>,
}
278
279mod duration_millis {
280 use serde::{Deserialize, Deserializer, Serializer};
281 use std::time::Duration;
282
283 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
284 where
285 S: Serializer,
286 {
287 serializer.serialize_u64(duration.as_millis() as u64)
288 }
289
290 #[allow(dead_code)]
291 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
292 where
293 D: Deserializer<'de>,
294 {
295 let millis = u64::deserialize(deserializer)?;
296 Ok(Duration::from_millis(millis))
297 }
298}
299
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// `duration` must serialize as a bare millisecond integer, and `success`
    /// must serialize as a JSON boolean.
    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: serde_json::json!({ "id": "123" }),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();
        for fragment in ["\"duration\":42", "\"success\":true"] {
            assert!(
                encoded.contains(fragment),
                "expected {} in serialized output {}",
                fragment,
                encoded
            );
        }
    }
}