1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12
13pub struct FunctionExecutor {
15 router: FunctionRouter,
16 registry: Arc<FunctionRegistry>,
17 default_timeout: Duration,
18}
19
20impl FunctionExecutor {
21 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
23 Self {
24 router: FunctionRouter::new(Arc::clone(®istry), db),
25 registry,
26 default_timeout: Duration::from_secs(30),
27 }
28 }
29
30 pub fn with_timeout(
32 registry: Arc<FunctionRegistry>,
33 db: Database,
34 default_timeout: Duration,
35 ) -> Self {
36 Self {
37 router: FunctionRouter::new(Arc::clone(®istry), db),
38 registry,
39 default_timeout,
40 }
41 }
42
43 pub fn with_dispatch(
45 registry: Arc<FunctionRegistry>,
46 db: Database,
47 job_dispatcher: Option<Arc<dyn JobDispatch>>,
48 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
49 ) -> Self {
50 let mut router = FunctionRouter::new(Arc::clone(®istry), db);
51 if let Some(jd) = job_dispatcher {
52 router = router.with_job_dispatcher(jd);
53 }
54 if let Some(wd) = workflow_dispatcher {
55 router = router.with_workflow_dispatcher(wd);
56 }
57 Self {
58 router,
59 registry,
60 default_timeout: Duration::from_secs(30),
61 }
62 }
63
64 pub async fn execute(
66 &self,
67 function_name: &str,
68 args: Value,
69 auth: AuthContext,
70 request: RequestMetadata,
71 ) -> Result<ExecutionResult> {
72 let start = std::time::Instant::now();
73
74 let fn_timeout = self.get_function_timeout(function_name);
76
77 let log_level = self.get_function_log_level(function_name);
79
80 let result = match timeout(
82 fn_timeout,
83 self.router
84 .route(function_name, args.clone(), auth, request),
85 )
86 .await
87 {
88 Ok(result) => result,
89 Err(_) => {
90 let duration = start.elapsed();
91 self.log_execution(
92 log_level,
93 function_name,
94 "unknown",
95 &args,
96 duration,
97 false,
98 Some(&format!("Timeout after {:?}", fn_timeout)),
99 );
100 return Err(ForgeError::Timeout(format!(
101 "Function '{}' timed out after {:?}",
102 function_name, fn_timeout
103 )));
104 }
105 };
106
107 let duration = start.elapsed();
108
109 match result {
110 Ok(route_result) => {
111 let (kind, value) = match route_result {
112 RouteResult::Query(v) => ("query", v),
113 RouteResult::Mutation(v) => ("mutation", v),
114 RouteResult::Job(v) => ("job", v),
115 RouteResult::Workflow(v) => ("workflow", v),
116 };
117
118 self.log_execution(log_level, function_name, kind, &args, duration, true, None);
119
120 Ok(ExecutionResult {
121 function_name: function_name.to_string(),
122 function_kind: kind.to_string(),
123 result: value,
124 duration,
125 success: true,
126 error: None,
127 })
128 }
129 Err(e) => {
130 let kind = self
131 .router
132 .get_function_kind(function_name)
133 .map(|k| k.to_string())
134 .unwrap_or_else(|| "unknown".to_string());
135
136 self.log_execution(
137 log_level,
138 function_name,
139 &kind,
140 &args,
141 duration,
142 false,
143 Some(&e.to_string()),
144 );
145
146 Err(e)
147 }
148 }
149 }
150
151 #[allow(clippy::too_many_arguments)]
153 fn log_execution(
154 &self,
155 log_level: &str,
156 function_name: &str,
157 kind: &str,
158 input: &Value,
159 duration: Duration,
160 success: bool,
161 error: Option<&str>,
162 ) {
163 macro_rules! log_fn {
164 ($level:ident) => {
165 if success {
166 $level!(
167 function = function_name,
168 kind = kind,
169 input = %input,
170 duration_ms = duration.as_millis() as u64,
171 success = success,
172 "Function executed"
173 );
174 } else {
175 $level!(
176 function = function_name,
177 kind = kind,
178 input = %input,
179 duration_ms = duration.as_millis() as u64,
180 success = success,
181 error = error,
182 "Function failed"
183 );
184 }
185 };
186 }
187
188 match log_level {
189 "off" => {}
190 "error" => log_fn!(error),
191 "warn" => log_fn!(warn),
192 "info" => log_fn!(info),
193 "debug" => log_fn!(debug),
194 _ => log_fn!(trace),
195 }
196 }
197
198 fn get_function_log_level(&self, function_name: &str) -> &'static str {
200 self.registry
201 .get(function_name)
202 .and_then(|entry| entry.info().log_level)
203 .unwrap_or("trace")
204 }
205
206 fn get_function_timeout(&self, function_name: &str) -> Duration {
208 self.registry
209 .get(function_name)
210 .and_then(|entry| entry.info().timeout)
211 .map(Duration::from_secs)
212 .unwrap_or(self.default_timeout)
213 }
214
215 pub fn has_function(&self, function_name: &str) -> bool {
217 self.router.has_function(function_name)
218 }
219}
220
221#[derive(Debug, Clone, serde::Serialize)]
223pub struct ExecutionResult {
224 pub function_name: String,
226 pub function_kind: String,
228 pub result: Value,
230 #[serde(with = "duration_millis")]
232 pub duration: Duration,
233 pub success: bool,
235 pub error: Option<String>,
237}
238
239mod duration_millis {
240 use serde::{Deserialize, Deserializer, Serializer};
241 use std::time::Duration;
242
243 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
244 where
245 S: Serializer,
246 {
247 serializer.serialize_u64(duration.as_millis() as u64)
248 }
249
250 #[allow(dead_code)]
251 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
252 where
253 D: Deserializer<'de>,
254 {
255 let millis = u64::deserialize(deserializer)?;
256 Ok(Duration::from_millis(millis))
257 }
258}
259
#[cfg(test)]
// Test code is allowed to unwrap, index and panic: a failure there is a test failure.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// The duration is written as whole milliseconds and the success flag is
    /// written as a JSON boolean.
    #[test]
    fn test_execution_result_serialization() {
        let sample = ExecutionResult {
            function_name: String::from("get_user"),
            function_kind: String::from("query"),
            result: serde_json::json!({"id": "123"}),
            duration: Duration::from_millis(42),
            success: true,
            error: None,
        };

        let encoded = serde_json::to_string(&sample).unwrap();

        let has_duration = encoded.contains("\"duration\":42");
        let has_success = encoded.contains("\"success\":true");
        assert!(has_duration);
        assert!(has_success);
    }
}