1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13pub struct FunctionExecutor {
15 router: FunctionRouter,
16 registry: Arc<FunctionRegistry>,
17 default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23 Self {
24 router: FunctionRouter::new(Arc::clone(®istry), db),
25 registry,
26 default_timeout: Duration::from_secs(30),
27 }
28 }
29
30 pub fn with_timeout(
32 registry: Arc<FunctionRegistry>,
33 db: Database,
34 default_timeout: Duration,
35 ) -> Self {
36 Self {
37 router: FunctionRouter::new(Arc::clone(®istry), db),
38 registry,
39 default_timeout,
40 }
41 }
42
43 pub fn with_dispatch(
45 registry: Arc<FunctionRegistry>,
46 db: Database,
47 job_dispatcher: Option<Arc<dyn JobDispatch>>,
48 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49 ) -> Self {
50 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
51 }
52
53 pub fn with_dispatch_and_issuer(
55 registry: Arc<FunctionRegistry>,
56 db: Database,
57 job_dispatcher: Option<Arc<dyn JobDispatch>>,
58 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
59 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
60 ) -> Self {
61 let mut router = FunctionRouter::new(Arc::clone(®istry), db);
62 if let Some(jd) = job_dispatcher {
63 router = router.with_job_dispatcher(jd);
64 }
65 if let Some(wd) = workflow_dispatcher {
66 router = router.with_workflow_dispatcher(wd);
67 }
68 if let Some(issuer) = token_issuer {
69 router = router.with_token_issuer(issuer);
70 }
71 Self {
72 router,
73 registry,
74 default_timeout: Duration::from_secs(30),
75 }
76 }
77
78 pub async fn execute(
80 &self,
81 function_name: &str,
82 args: Value,
83 auth: AuthContext,
84 request: RequestMetadata,
85 ) -> Result<ExecutionResult> {
86 let start = std::time::Instant::now();
87 let fn_timeout = self.get_function_timeout(function_name);
88 let log_level = self.get_function_log_level(function_name);
89
90 let kind = self
91 .router
92 .get_function_kind(function_name)
93 .map(|k| k.to_string())
94 .unwrap_or_else(|| "unknown".to_string());
95
96 let span = tracing::info_span!(
97 "fn.execute",
98 function = function_name,
99 kind = %kind,
100 );
101
102 let result = match timeout(
103 fn_timeout,
104 self.router
105 .route(function_name, args.clone(), auth, request)
106 .instrument(span),
107 )
108 .await
109 {
110 Ok(result) => result,
111 Err(_) => {
112 let duration = start.elapsed();
113 self.log_execution(
114 log_level,
115 function_name,
116 "unknown",
117 &args,
118 duration,
119 false,
120 Some(&format!("Timeout after {:?}", fn_timeout)),
121 );
122 crate::observability::record_fn_execution(
123 function_name,
124 &kind,
125 false,
126 duration.as_secs_f64(),
127 );
128 return Err(ForgeError::Timeout(format!(
129 "Function '{}' timed out after {:?}",
130 function_name, fn_timeout
131 )));
132 }
133 };
134
135 let duration = start.elapsed();
136
137 match result {
138 Ok(route_result) => {
139 let (result_kind, value) = match route_result {
140 RouteResult::Query(v) => ("query", v),
141 RouteResult::Mutation(v) => ("mutation", v),
142 RouteResult::Job(v) => ("job", v),
143 RouteResult::Workflow(v) => ("workflow", v),
144 };
145
146 self.log_execution(
147 log_level,
148 function_name,
149 result_kind,
150 &args,
151 duration,
152 true,
153 None,
154 );
155 crate::observability::record_fn_execution(
156 function_name,
157 result_kind,
158 true,
159 duration.as_secs_f64(),
160 );
161
162 Ok(ExecutionResult {
163 function_name: function_name.to_string(),
164 function_kind: result_kind.to_string(),
165 result: value,
166 duration,
167 success: true,
168 error: None,
169 })
170 }
171 Err(e) => {
172 self.log_execution(
173 log_level,
174 function_name,
175 &kind,
176 &args,
177 duration,
178 false,
179 Some(&e.to_string()),
180 );
181 crate::observability::record_fn_execution(
182 function_name,
183 &kind,
184 false,
185 duration.as_secs_f64(),
186 );
187
188 Err(e)
189 }
190 }
191 }
192
193 #[allow(clippy::too_many_arguments)]
195 fn log_execution(
196 &self,
197 log_level: &str,
198 function_name: &str,
199 kind: &str,
200 input: &Value,
201 duration: Duration,
202 success: bool,
203 error: Option<&str>,
204 ) {
205 if !success {
208 error!(
209 function = function_name,
210 kind = kind,
211 duration_ms = duration.as_millis() as u64,
212 error = error,
213 "Function failed"
214 );
215 debug!(
216 function = function_name,
217 input = %input,
218 "Function input"
219 );
220 return;
221 }
222
223 macro_rules! log_fn {
224 ($level:ident) => {{
225 $level!(
226 function = function_name,
227 kind = kind,
228 duration_ms = duration.as_millis() as u64,
229 "Function executed"
230 );
231 debug!(
232 function = function_name,
233 input = %input,
234 "Function input"
235 );
236 }};
237 }
238
239 match log_level {
240 "off" => {}
241 "error" => log_fn!(error),
242 "warn" => log_fn!(warn),
243 "info" => log_fn!(info),
244 "debug" => log_fn!(debug),
245 _ => log_fn!(trace),
246 }
247 }
248
249 fn get_function_log_level(&self, function_name: &str) -> &'static str {
252 self.registry
253 .get(function_name)
254 .map(|entry| {
255 entry.info().log_level.unwrap_or(match entry.kind() {
256 forge_core::FunctionKind::Mutation => "info",
257 forge_core::FunctionKind::Query => "debug",
258 })
259 })
260 .unwrap_or("info")
261 }
262
263 fn get_function_timeout(&self, function_name: &str) -> Duration {
265 self.registry
266 .get(function_name)
267 .and_then(|entry| entry.info().timeout)
268 .map(Duration::from_secs)
269 .unwrap_or(self.default_timeout)
270 }
271
272 pub fn has_function(&self, function_name: &str) -> bool {
274 self.router.has_function(function_name)
275 }
276}
277
278#[derive(Debug, Clone, serde::Serialize)]
280pub struct ExecutionResult {
281 pub function_name: String,
283 pub function_kind: String,
285 pub result: Value,
287 #[serde(with = "duration_millis")]
289 pub duration: Duration,
290 pub success: bool,
292 pub error: Option<String>,
294}
295
296mod duration_millis {
297 use serde::{Deserialize, Deserializer, Serializer};
298 use std::time::Duration;
299
300 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
301 where
302 S: Serializer,
303 {
304 serializer.serialize_u64(duration.as_millis() as u64)
305 }
306
307 #[allow(dead_code)]
308 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
309 where
310 D: Deserializer<'de>,
311 {
312 let millis = u64::deserialize(deserializer)?;
313 Ok(Duration::from_millis(millis))
314 }
315}
316
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// A 42 ms duration must serialize as the bare integer `42`, and the
    /// `success` flag must appear as a JSON boolean.
    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();

        for expected in ["\"duration\":42", "\"success\":true"] {
            assert!(encoded.contains(expected));
        }
    }
}