1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13pub struct FunctionExecutor {
15 router: FunctionRouter,
16 registry: Arc<FunctionRegistry>,
17 default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23 Self {
24 router: FunctionRouter::new(Arc::clone(®istry), db),
25 registry,
26 default_timeout: Duration::from_secs(30),
27 }
28 }
29
30 pub fn with_timeout(
32 registry: Arc<FunctionRegistry>,
33 db: Database,
34 default_timeout: Duration,
35 ) -> Self {
36 Self {
37 router: FunctionRouter::new(Arc::clone(®istry), db),
38 registry,
39 default_timeout,
40 }
41 }
42
43 pub fn with_dispatch(
45 registry: Arc<FunctionRegistry>,
46 db: Database,
47 job_dispatcher: Option<Arc<dyn JobDispatch>>,
48 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49 ) -> Self {
50 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
51 }
52
53 pub fn with_dispatch_and_issuer(
55 registry: Arc<FunctionRegistry>,
56 db: Database,
57 job_dispatcher: Option<Arc<dyn JobDispatch>>,
58 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
59 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
60 ) -> Self {
61 let mut router = FunctionRouter::new(Arc::clone(®istry), db);
62 if let Some(jd) = job_dispatcher {
63 router = router.with_job_dispatcher(jd);
64 }
65 if let Some(wd) = workflow_dispatcher {
66 router = router.with_workflow_dispatcher(wd);
67 }
68 if let Some(issuer) = token_issuer {
69 router = router.with_token_issuer(issuer);
70 }
71 Self {
72 router,
73 registry,
74 default_timeout: Duration::from_secs(30),
75 }
76 }
77
78 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
80 self.router.set_token_ttl(ttl);
81 }
82
83 pub async fn execute(
85 &self,
86 function_name: &str,
87 args: Value,
88 auth: AuthContext,
89 request: RequestMetadata,
90 ) -> Result<ExecutionResult> {
91 let start = std::time::Instant::now();
92 let fn_timeout = self.get_function_timeout(function_name);
93 let log_level = self.get_function_log_level(function_name);
94
95 let kind = self
96 .router
97 .get_function_kind(function_name)
98 .map(|k| k.to_string())
99 .unwrap_or_else(|| "unknown".to_string());
100
101 let span = tracing::info_span!(
102 "fn.execute",
103 function = function_name,
104 fn.kind = %kind,
105 );
106
107 let result = match timeout(
108 fn_timeout,
109 self.router
110 .route(function_name, args.clone(), auth, request)
111 .instrument(span),
112 )
113 .await
114 {
115 Ok(result) => result,
116 Err(_) => {
117 let duration = start.elapsed();
118 self.log_execution(
119 log_level,
120 function_name,
121 "unknown",
122 &args,
123 duration,
124 false,
125 Some(&format!("Timeout after {:?}", fn_timeout)),
126 );
127 crate::observability::record_fn_execution(
128 function_name,
129 &kind,
130 false,
131 duration.as_secs_f64(),
132 );
133 return Err(ForgeError::Timeout(format!(
134 "Function '{}' timed out after {:?}",
135 function_name, fn_timeout
136 )));
137 }
138 };
139
140 let duration = start.elapsed();
141
142 match result {
143 Ok(route_result) => {
144 let (result_kind, value) = match route_result {
145 RouteResult::Query(v) => ("query", v),
146 RouteResult::Mutation(v) => ("mutation", v),
147 RouteResult::Job(v) => ("job", v),
148 RouteResult::Workflow(v) => ("workflow", v),
149 };
150
151 self.log_execution(
152 log_level,
153 function_name,
154 result_kind,
155 &args,
156 duration,
157 true,
158 None,
159 );
160 crate::observability::record_fn_execution(
161 function_name,
162 result_kind,
163 true,
164 duration.as_secs_f64(),
165 );
166
167 Ok(ExecutionResult {
168 function_name: function_name.to_string(),
169 function_kind: result_kind.to_string(),
170 result: value,
171 duration,
172 success: true,
173 error: None,
174 })
175 }
176 Err(e) => {
177 self.log_execution(
178 log_level,
179 function_name,
180 &kind,
181 &args,
182 duration,
183 false,
184 Some(&e.to_string()),
185 );
186 crate::observability::record_fn_execution(
187 function_name,
188 &kind,
189 false,
190 duration.as_secs_f64(),
191 );
192
193 Err(e)
194 }
195 }
196 }
197
198 #[allow(clippy::too_many_arguments)]
200 fn log_execution(
201 &self,
202 log_level: &str,
203 function_name: &str,
204 kind: &str,
205 input: &Value,
206 duration: Duration,
207 success: bool,
208 error: Option<&str>,
209 ) {
210 if !success {
213 error!(
214 function = function_name,
215 kind = kind,
216 duration_ms = duration.as_millis() as u64,
217 error = error,
218 "Function failed"
219 );
220 debug!(
221 function = function_name,
222 input = %input,
223 "Function input"
224 );
225 return;
226 }
227
228 macro_rules! log_fn {
229 ($level:ident) => {{
230 $level!(
231 function = function_name,
232 kind = kind,
233 duration_ms = duration.as_millis() as u64,
234 "Function executed"
235 );
236 debug!(
237 function = function_name,
238 input = %input,
239 "Function input"
240 );
241 }};
242 }
243
244 match log_level {
245 "off" => {}
246 "error" => log_fn!(error),
247 "warn" => log_fn!(warn),
248 "info" => log_fn!(info),
249 "debug" => log_fn!(debug),
250 _ => log_fn!(trace),
251 }
252 }
253
254 fn get_function_log_level(&self, function_name: &str) -> &'static str {
257 self.registry
258 .get(function_name)
259 .map(|entry| {
260 entry.info().log_level.unwrap_or(match entry.kind() {
261 forge_core::FunctionKind::Mutation => "info",
262 forge_core::FunctionKind::Query => "debug",
263 })
264 })
265 .unwrap_or("info")
266 }
267
268 fn get_function_timeout(&self, function_name: &str) -> Duration {
270 self.registry
271 .get(function_name)
272 .and_then(|entry| entry.info().timeout)
273 .map(Duration::from_secs)
274 .unwrap_or(self.default_timeout)
275 }
276
277 pub fn has_function(&self, function_name: &str) -> bool {
279 self.router.has_function(function_name)
280 }
281}
282
283#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
285pub struct ExecutionResult {
286 pub function_name: String,
288 pub function_kind: String,
290 pub result: Value,
292 #[serde(with = "duration_millis")]
294 pub duration: Duration,
295 pub success: bool,
297 pub error: Option<String>,
299}
300
301mod duration_millis {
302 use serde::{Deserialize, Deserializer, Serializer};
303 use std::time::Duration;
304
305 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
306 where
307 S: Serializer,
308 {
309 serializer.serialize_u64(duration.as_millis() as u64)
310 }
311
312 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
313 where
314 D: Deserializer<'de>,
315 {
316 let millis = u64::deserialize(deserializer)?;
317 Ok(Duration::from_millis(millis))
318 }
319}
320
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a successful result for the given name, kind, payload and time.
    fn sample(name: &str, kind: &str, result: Value, duration: Duration) -> ExecutionResult {
        ExecutionResult {
            function_name: name.to_string(),
            function_kind: kind.to_string(),
            result,
            duration,
            success: true,
            error: None,
        }
    }

    #[test]
    fn test_execution_result_serialization() {
        let result = sample(
            "get_user",
            "query",
            serde_json::json!({"id": "123"}),
            Duration::from_millis(42),
        );

        // Assert on parsed top-level fields rather than substrings, which
        // could also match data nested inside `result`.
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["function_name"], "get_user");
        assert_eq!(json["function_kind"], "query");
        assert_eq!(json["result"], serde_json::json!({"id": "123"}));
        assert_eq!(json["duration"], 42);
        assert_eq!(json["success"], true);
        assert!(json["error"].is_null());
    }

    #[test]
    fn test_execution_result_round_trip() {
        let original = sample(
            "create_user",
            "mutation",
            serde_json::json!({"id": "456"}),
            Duration::from_millis(100),
        );

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.function_name, "create_user");
        assert_eq!(deserialized.function_kind, "mutation");
        assert_eq!(deserialized.result, serde_json::json!({"id": "456"}));
        assert_eq!(deserialized.duration, Duration::from_millis(100));
        assert!(deserialized.success);
        assert!(deserialized.error.is_none());
    }

    #[test]
    fn test_execution_result_round_trip_with_error() {
        let original = ExecutionResult {
            success: false,
            error: Some("not found".to_string()),
            ..sample("delete_user", "mutation", Value::Null, Duration::from_millis(7))
        };

        let json = serde_json::to_string(&original).unwrap();
        let deserialized: ExecutionResult = serde_json::from_str(&json).unwrap();

        assert!(!deserialized.success);
        assert_eq!(deserialized.error.as_deref(), Some("not found"));
        assert_eq!(deserialized.result, Value::Null);
        assert_eq!(deserialized.duration, Duration::from_millis(7));
    }

    #[test]
    fn test_duration_drops_sub_millisecond_precision() {
        // The wire format stores whole milliseconds only.
        let original = sample("ping", "query", Value::Null, Duration::from_micros(1_999));

        let json = serde_json::to_value(&original).unwrap();
        assert_eq!(json["duration"], 1);

        let deserialized: ExecutionResult = serde_json::from_value(json).unwrap();
        assert_eq!(deserialized.duration, Duration::from_millis(1));
    }
}