1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13pub struct FunctionExecutor {
15 router: FunctionRouter,
16 registry: Arc<FunctionRegistry>,
17 default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23 Self {
24 router: FunctionRouter::new(Arc::clone(®istry), db),
25 registry,
26 default_timeout: Duration::from_secs(30),
27 }
28 }
29
30 pub fn with_timeout(
32 registry: Arc<FunctionRegistry>,
33 db: Database,
34 default_timeout: Duration,
35 ) -> Self {
36 Self {
37 router: FunctionRouter::new(Arc::clone(®istry), db),
38 registry,
39 default_timeout,
40 }
41 }
42
43 pub fn with_dispatch(
45 registry: Arc<FunctionRegistry>,
46 db: Database,
47 job_dispatcher: Option<Arc<dyn JobDispatch>>,
48 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49 ) -> Self {
50 let mut router = FunctionRouter::new(Arc::clone(®istry), db);
51 if let Some(jd) = job_dispatcher {
52 router = router.with_job_dispatcher(jd);
53 }
54 if let Some(wd) = workflow_dispatcher {
55 router = router.with_workflow_dispatcher(wd);
56 }
57 Self {
58 router,
59 registry,
60 default_timeout: Duration::from_secs(30),
61 }
62 }
63
64 pub async fn execute(
66 &self,
67 function_name: &str,
68 args: Value,
69 auth: AuthContext,
70 request: RequestMetadata,
71 ) -> Result<ExecutionResult> {
72 let start = std::time::Instant::now();
73 let fn_timeout = self.get_function_timeout(function_name);
74 let log_level = self.get_function_log_level(function_name);
75
76 let kind = self
77 .router
78 .get_function_kind(function_name)
79 .map(|k| k.to_string())
80 .unwrap_or_else(|| "unknown".to_string());
81
82 let span = tracing::info_span!(
83 "fn.execute",
84 function = function_name,
85 kind = %kind,
86 );
87
88 let result = match timeout(
89 fn_timeout,
90 self.router
91 .route(function_name, args.clone(), auth, request)
92 .instrument(span),
93 )
94 .await
95 {
96 Ok(result) => result,
97 Err(_) => {
98 let duration = start.elapsed();
99 self.log_execution(
100 log_level,
101 function_name,
102 "unknown",
103 &args,
104 duration,
105 false,
106 Some(&format!("Timeout after {:?}", fn_timeout)),
107 );
108 return Err(ForgeError::Timeout(format!(
109 "Function '{}' timed out after {:?}",
110 function_name, fn_timeout
111 )));
112 }
113 };
114
115 let duration = start.elapsed();
116
117 match result {
118 Ok(route_result) => {
119 let (result_kind, value) = match route_result {
120 RouteResult::Query(v) => ("query", v),
121 RouteResult::Mutation(v) => ("mutation", v),
122 RouteResult::Job(v) => ("job", v),
123 RouteResult::Workflow(v) => ("workflow", v),
124 };
125
126 self.log_execution(
127 log_level,
128 function_name,
129 result_kind,
130 &args,
131 duration,
132 true,
133 None,
134 );
135
136 Ok(ExecutionResult {
137 function_name: function_name.to_string(),
138 function_kind: result_kind.to_string(),
139 result: value,
140 duration,
141 success: true,
142 error: None,
143 })
144 }
145 Err(e) => {
146 self.log_execution(
147 log_level,
148 function_name,
149 &kind,
150 &args,
151 duration,
152 false,
153 Some(&e.to_string()),
154 );
155
156 Err(e)
157 }
158 }
159 }
160
161 #[allow(clippy::too_many_arguments)]
163 fn log_execution(
164 &self,
165 log_level: &str,
166 function_name: &str,
167 kind: &str,
168 input: &Value,
169 duration: Duration,
170 success: bool,
171 error: Option<&str>,
172 ) {
173 macro_rules! log_fn {
174 ($level:ident) => {
175 if success {
176 $level!(
177 function = function_name,
178 kind = kind,
179 input = %input,
180 duration_ms = duration.as_millis() as u64,
181 success = success,
182 "Function executed"
183 );
184 } else {
185 $level!(
186 function = function_name,
187 kind = kind,
188 input = %input,
189 duration_ms = duration.as_millis() as u64,
190 success = success,
191 error = error,
192 "Function failed"
193 );
194 }
195 };
196 }
197
198 match log_level {
199 "off" => {}
200 "error" => log_fn!(error),
201 "warn" => log_fn!(warn),
202 "info" => log_fn!(info),
203 "debug" => log_fn!(debug),
204 _ => log_fn!(trace),
205 }
206 }
207
208 fn get_function_log_level(&self, function_name: &str) -> &'static str {
211 self.registry
212 .get(function_name)
213 .map(|entry| {
214 entry.info().log_level.unwrap_or(match entry.kind() {
215 forge_core::FunctionKind::Mutation => "info",
216 forge_core::FunctionKind::Query => "debug",
217 })
218 })
219 .unwrap_or("info")
220 }
221
222 fn get_function_timeout(&self, function_name: &str) -> Duration {
224 self.registry
225 .get(function_name)
226 .and_then(|entry| entry.info().timeout)
227 .map(Duration::from_secs)
228 .unwrap_or(self.default_timeout)
229 }
230
231 pub fn has_function(&self, function_name: &str) -> bool {
233 self.router.has_function(function_name)
234 }
235}
236
237#[derive(Debug, Clone, serde::Serialize)]
239pub struct ExecutionResult {
240 pub function_name: String,
242 pub function_kind: String,
244 pub result: Value,
246 #[serde(with = "duration_millis")]
248 pub duration: Duration,
249 pub success: bool,
251 pub error: Option<String>,
253}
254
255mod duration_millis {
256 use serde::{Deserialize, Deserializer, Serializer};
257 use std::time::Duration;
258
259 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
260 where
261 S: Serializer,
262 {
263 serializer.serialize_u64(duration.as_millis() as u64)
264 }
265
266 #[allow(dead_code)]
267 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
268 where
269 D: Deserializer<'de>,
270 {
271 let millis = u64::deserialize(deserializer)?;
272 Ok(Duration::from_millis(millis))
273 }
274}
275
#[cfg(test)]
// Test code: panicking on an unexpected value is the desired failure mode.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The serialized form exposes every field and writes `duration` as whole
    /// milliseconds.
    #[test]
    fn test_execution_result_serialization() {
        let result = ExecutionResult {
            function_name: "get_user".to_string(),
            function_kind: "query".to_string(),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        // Compare structurally: a substring check such as `"duration":42`
        // would also accept `420`.
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "function_name": "get_user",
                "function_kind": "query",
                "result": {"id": "123"},
                "duration": 42,
                "success": true,
                "error": null,
            })
        );
    }
}