1use std::sync::Arc;
2use std::time::Duration;
3
4use forge_core::{AuthContext, ForgeError, JobDispatch, RequestMetadata, Result, WorkflowDispatch};
5use serde_json::Value;
6use tokio::time::timeout;
7use tracing::{Instrument, debug, error, info, trace, warn};
8
9use super::registry::FunctionRegistry;
10use super::router::{FunctionRouter, RouteResult};
11use crate::db::Database;
12use crate::signals::SignalsCollector;
13
14pub struct FunctionExecutor {
16 router: FunctionRouter,
17 registry: Arc<FunctionRegistry>,
18 default_timeout: Duration,
19 signals_collector: Option<SignalsCollector>,
20 signals_server_secret: String,
21}
22
23impl FunctionExecutor {
24 pub fn new(registry: Arc<FunctionRegistry>, db: Database) -> Self {
26 Self {
27 router: FunctionRouter::new(Arc::clone(®istry), db),
28 registry,
29 default_timeout: Duration::from_secs(30),
30 signals_collector: None,
31 signals_server_secret: String::new(),
32 }
33 }
34
35 pub fn with_timeout(
37 registry: Arc<FunctionRegistry>,
38 db: Database,
39 default_timeout: Duration,
40 ) -> Self {
41 Self {
42 router: FunctionRouter::new(Arc::clone(®istry), db),
43 registry,
44 default_timeout,
45 signals_collector: None,
46 signals_server_secret: String::new(),
47 }
48 }
49
50 pub fn with_dispatch(
52 registry: Arc<FunctionRegistry>,
53 db: Database,
54 job_dispatcher: Option<Arc<dyn JobDispatch>>,
55 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
56 ) -> Self {
57 Self::with_dispatch_and_issuer(registry, db, job_dispatcher, workflow_dispatcher, None)
58 }
59
60 pub fn with_dispatch_and_issuer(
62 registry: Arc<FunctionRegistry>,
63 db: Database,
64 job_dispatcher: Option<Arc<dyn JobDispatch>>,
65 workflow_dispatcher: Option<Arc<dyn WorkflowDispatch>>,
66 token_issuer: Option<Arc<dyn forge_core::TokenIssuer>>,
67 ) -> Self {
68 let mut router = FunctionRouter::new(Arc::clone(®istry), db);
69 if let Some(jd) = job_dispatcher {
70 router = router.with_job_dispatcher(jd);
71 }
72 if let Some(wd) = workflow_dispatcher {
73 router = router.with_workflow_dispatcher(wd);
74 }
75 if let Some(issuer) = token_issuer {
76 router = router.with_token_issuer(issuer);
77 }
78 Self {
79 router,
80 registry,
81 default_timeout: Duration::from_secs(30),
82 signals_collector: None,
83 signals_server_secret: String::new(),
84 }
85 }
86
87 pub fn set_signals_collector(&mut self, collector: SignalsCollector, server_secret: String) {
89 self.signals_collector = Some(collector);
90 self.signals_server_secret = server_secret;
91 }
92
93 pub fn set_token_ttl(&mut self, ttl: forge_core::AuthTokenTtl) {
95 self.router.set_token_ttl(ttl);
96 }
97
98 pub async fn execute(
100 &self,
101 function_name: &str,
102 args: Value,
103 auth: AuthContext,
104 request: RequestMetadata,
105 ) -> Result<ExecutionResult> {
106 let start = std::time::Instant::now();
107 let fn_timeout = self.get_function_timeout(function_name);
108 let log_level = self.get_function_log_level(function_name);
109
110 let kind = self
111 .router
112 .get_function_kind(function_name)
113 .map(|k| k.to_string())
114 .unwrap_or_else(|| "unknown".to_string());
115
116 let signal_ctx = self.signals_collector.as_ref().map(|_| SignalContext {
118 user_id: auth.user_id(),
119 tenant_id: auth.tenant_id(),
120 correlation_id: request.correlation_id.clone(),
121 client_ip: request.client_ip.clone(),
122 user_agent: request.user_agent.clone(),
123 });
124
125 let span = tracing::info_span!(
126 "fn.execute",
127 function = function_name,
128 fn.kind = %kind,
129 );
130
131 let result = match timeout(
132 fn_timeout,
133 self.router
134 .route(function_name, args.clone(), auth, request)
135 .instrument(span),
136 )
137 .await
138 {
139 Ok(result) => result,
140 Err(_) => {
141 let duration = start.elapsed();
142 self.log_execution(
143 log_level,
144 function_name,
145 "unknown",
146 &args,
147 duration,
148 false,
149 Some(&format!("Timeout after {:?}", fn_timeout)),
150 );
151 crate::observability::record_fn_execution(
152 function_name,
153 &kind,
154 false,
155 duration.as_secs_f64(),
156 );
157 self.emit_signal(function_name, &kind, duration, false, &signal_ctx);
158 return Err(ForgeError::Timeout(format!(
159 "Function '{}' timed out after {:?}",
160 function_name, fn_timeout
161 )));
162 }
163 };
164
165 let duration = start.elapsed();
166
167 match result {
168 Ok(route_result) => {
169 let (result_kind, value) = match route_result {
170 RouteResult::Query(v) => ("query", v),
171 RouteResult::Mutation(v) => ("mutation", v),
172 RouteResult::Job(v) => ("job", v),
173 RouteResult::Workflow(v) => ("workflow", v),
174 };
175
176 self.log_execution(
177 log_level,
178 function_name,
179 result_kind,
180 &args,
181 duration,
182 true,
183 None,
184 );
185 crate::observability::record_fn_execution(
186 function_name,
187 result_kind,
188 true,
189 duration.as_secs_f64(),
190 );
191 self.emit_signal(function_name, result_kind, duration, true, &signal_ctx);
192
193 Ok(ExecutionResult {
194 function_name: function_name.to_string(),
195 function_kind: result_kind.to_string(),
196 result: value,
197 duration,
198 success: true,
199 error: None,
200 })
201 }
202 Err(e) => {
203 self.log_execution(
204 log_level,
205 function_name,
206 &kind,
207 &args,
208 duration,
209 false,
210 Some(&e.to_string()),
211 );
212 crate::observability::record_fn_execution(
213 function_name,
214 &kind,
215 false,
216 duration.as_secs_f64(),
217 );
218 self.emit_signal(function_name, &kind, duration, false, &signal_ctx);
219
220 Err(e)
221 }
222 }
223 }
224
225 fn emit_signal(
227 &self,
228 function_name: &str,
229 function_kind: &str,
230 duration: Duration,
231 success: bool,
232 ctx: &Option<SignalContext>,
233 ) {
234 let Some(collector) = &self.signals_collector else {
235 return;
236 };
237 let Some(ctx) = ctx else { return };
238
239 let is_bot = crate::signals::bot::is_bot(ctx.user_agent.as_deref());
240 let visitor_id = ctx.client_ip.as_ref().map(|_| {
241 crate::signals::visitor::generate_visitor_id(
242 ctx.client_ip.as_deref(),
243 ctx.user_agent.as_deref(),
244 &self.signals_server_secret,
245 )
246 });
247
248 let event = forge_core::signals::SignalEvent::rpc_call(
249 function_name,
250 function_kind,
251 duration.as_millis() as i32,
252 success,
253 ctx.user_id,
254 ctx.tenant_id,
255 ctx.correlation_id.clone(),
256 ctx.client_ip.clone(),
257 ctx.user_agent.clone(),
258 visitor_id,
259 is_bot,
260 );
261 collector.try_send(event);
262 }
263
264 #[allow(clippy::too_many_arguments)]
266 fn log_execution(
267 &self,
268 log_level: &str,
269 function_name: &str,
270 kind: &str,
271 input: &Value,
272 duration: Duration,
273 success: bool,
274 error: Option<&str>,
275 ) {
276 if !success {
279 error!(
280 function = function_name,
281 kind = kind,
282 duration_ms = duration.as_millis() as u64,
283 error = error,
284 "Function failed"
285 );
286 debug!(
287 function = function_name,
288 input = %input,
289 "Function input"
290 );
291 return;
292 }
293
294 macro_rules! log_fn {
295 ($level:ident) => {{
296 $level!(
297 function = function_name,
298 kind = kind,
299 duration_ms = duration.as_millis() as u64,
300 "Function executed"
301 );
302 debug!(
303 function = function_name,
304 input = %input,
305 "Function input"
306 );
307 }};
308 }
309
310 match log_level {
311 "off" => {}
312 "error" => log_fn!(error),
313 "warn" => log_fn!(warn),
314 "info" => log_fn!(info),
315 "debug" => log_fn!(debug),
316 _ => log_fn!(trace),
317 }
318 }
319
320 fn get_function_log_level(&self, function_name: &str) -> &'static str {
323 self.registry
324 .get(function_name)
325 .map(|entry| {
326 entry.info().log_level.unwrap_or(match entry.kind() {
327 forge_core::FunctionKind::Mutation => "info",
328 forge_core::FunctionKind::Query => "debug",
329 })
330 })
331 .unwrap_or("info")
332 }
333
334 fn get_function_timeout(&self, function_name: &str) -> Duration {
336 self.registry
337 .get(function_name)
338 .and_then(|entry| entry.info().timeout)
339 .map(Duration::from_secs)
340 .unwrap_or(self.default_timeout)
341 }
342
343 pub fn function_info(&self, function_name: &str) -> Option<forge_core::FunctionInfo> {
345 self.registry.get(function_name).map(|e| e.info().clone())
346 }
347
348 pub fn has_function(&self, function_name: &str) -> bool {
350 self.router.has_function(function_name)
351 }
352}
353
354struct SignalContext {
356 user_id: Option<uuid::Uuid>,
357 tenant_id: Option<uuid::Uuid>,
358 correlation_id: Option<String>,
359 client_ip: Option<String>,
360 user_agent: Option<String>,
361}
362
363#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
365pub struct ExecutionResult {
366 pub function_name: String,
368 pub function_kind: String,
370 pub result: Value,
372 #[serde(with = "duration_millis")]
374 pub duration: Duration,
375 pub success: bool,
377 pub error: Option<String>,
379}
380
381mod duration_millis {
382 use serde::{Deserialize, Deserializer, Serializer};
383 use std::time::Duration;
384
385 pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
386 where
387 S: Serializer,
388 {
389 serializer.serialize_u64(duration.as_millis() as u64)
390 }
391
392 pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
393 where
394 D: Deserializer<'de>,
395 {
396 let millis = u64::deserialize(deserializer)?;
397 Ok(Duration::from_millis(millis))
398 }
399}
400
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Builds a successful `ExecutionResult` from its variable parts.
    fn successful(name: &str, kind: &str, payload: Value, millis: u64) -> ExecutionResult {
        ExecutionResult {
            function_name: name.to_string(),
            function_kind: kind.to_string(),
            result: payload,
            duration: Duration::from_millis(millis),
            success: true,
            error: None,
        }
    }

    #[test]
    fn test_execution_result_serialization() {
        let sample = successful("get_user", "query", serde_json::json!({"id": "123"}), 42);
        let encoded = serde_json::to_string(&sample).unwrap();

        for fragment in ["\"duration\":42", "\"success\":true"] {
            assert!(encoded.contains(fragment), "missing {fragment} in {encoded}");
        }
    }

    #[test]
    fn test_execution_result_round_trip() {
        let sample = successful("create_user", "mutation", serde_json::json!({"id": "456"}), 100);

        let encoded = serde_json::to_string(&sample).unwrap();
        let decoded: ExecutionResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.function_name, "create_user");
        assert_eq!(decoded.duration, Duration::from_millis(100));
        assert!(decoded.success);
    }
}