1use crate::request_chaining::{
7 ChainConfig, ChainDefinition, ChainExecutionContext, ChainLink, ChainResponse,
8 ChainTemplatingContext, RequestChainRegistry,
9};
10use crate::request_scripting::{ScriptContext, ScriptEngine};
11use crate::templating::{expand_str_with_context, TemplatingContext};
12use crate::{Error, Result};
13use chrono::Utc;
14use futures::future::join_all;
15use reqwest::{
16 header::{HeaderMap, HeaderName, HeaderValue},
17 Client, Method,
18};
19use serde_json::Value;
20use std::collections::{HashMap, HashSet};
21use std::str::FromStr;
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use tokio::time::{timeout, Duration};
25
26#[derive(Debug, Clone)]
28pub struct ExecutionRecord {
29 pub executed_at: String,
31 pub result: ChainExecutionResult,
33}
34
35#[derive(Debug)]
37pub struct ChainExecutionEngine {
38 http_client: Client,
40 registry: Arc<RequestChainRegistry>,
42 config: ChainConfig,
44 execution_history: Arc<Mutex<HashMap<String, Vec<ExecutionRecord>>>>,
46 script_engine: ScriptEngine,
48}
49
50impl ChainExecutionEngine {
51 pub fn new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Self {
58 Self::try_new(registry, config)
59 .unwrap_or_else(|e| {
60 panic!(
61 "Failed to create HTTP client for chain execution engine: {}. \
62 This typically indicates a system configuration issue (e.g., invalid timeout value).",
63 e
64 )
65 })
66 }
67
68 pub fn try_new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Result<Self> {
72 let http_client = Client::builder()
73 .timeout(Duration::from_secs(config.global_timeout_secs))
74 .build()
75 .map_err(|e| {
76 Error::generic(format!(
77 "Failed to create HTTP client: {}. \
78 Check that the timeout value ({}) is valid.",
79 e, config.global_timeout_secs
80 ))
81 })?;
82
83 Ok(Self {
84 http_client,
85 registry,
86 config,
87 execution_history: Arc::new(Mutex::new(HashMap::new())),
88 script_engine: ScriptEngine::new(),
89 })
90 }
91
92 pub async fn execute_chain(
94 &self,
95 chain_id: &str,
96 variables: Option<serde_json::Value>,
97 ) -> Result<ChainExecutionResult> {
98 let chain = self
99 .registry
100 .get_chain(chain_id)
101 .await
102 .ok_or_else(|| Error::generic(format!("Chain '{}' not found", chain_id)))?;
103
104 let result = self.execute_chain_definition(&chain, variables).await?;
105
106 let record = ExecutionRecord {
108 executed_at: Utc::now().to_rfc3339(),
109 result: result.clone(),
110 };
111
112 let mut history = self.execution_history.lock().await;
113 history.entry(chain_id.to_string()).or_insert_with(Vec::new).push(record);
114
115 Ok(result)
116 }
117
118 pub async fn get_chain_history(&self, chain_id: &str) -> Vec<ExecutionRecord> {
120 let history = self.execution_history.lock().await;
121 history.get(chain_id).cloned().unwrap_or_default()
122 }
123
124 pub async fn execute_chain_definition(
126 &self,
127 chain_def: &ChainDefinition,
128 variables: Option<serde_json::Value>,
129 ) -> Result<ChainExecutionResult> {
130 self.registry.validate_chain(chain_def).await?;
132
133 let start_time = std::time::Instant::now();
134 let mut execution_context = ChainExecutionContext::new(chain_def.clone());
135
136 for (key, value) in &chain_def.variables {
138 execution_context
139 .templating
140 .chain_context
141 .set_variable(key.clone(), value.clone());
142 }
143
144 if let Some(serde_json::Value::Object(map)) = variables {
146 for (key, value) in map {
147 execution_context.templating.chain_context.set_variable(key, value);
148 }
149 }
150
151 if self.config.enable_parallel_execution {
152 self.execute_with_parallelism(&mut execution_context).await
153 } else {
154 self.execute_sequential(&mut execution_context).await
155 }
156 .map(|_| ChainExecutionResult {
157 chain_id: chain_def.id.clone(),
158 status: ChainExecutionStatus::Successful,
159 total_duration_ms: start_time.elapsed().as_millis() as u64,
160 request_results: execution_context.templating.chain_context.responses.clone(),
161 error_message: None,
162 })
163 }
164
165 async fn execute_with_parallelism(
167 &self,
168 execution_context: &mut ChainExecutionContext,
169 ) -> Result<()> {
170 let dep_graph = self.build_dependency_graph(&execution_context.definition.links);
171 let topo_order = self.topological_sort(&dep_graph)?;
172
173 let mut level_groups = vec![];
175 let mut processed = HashSet::new();
176
177 for request_id in topo_order {
178 if !processed.contains(&request_id) {
179 let mut level = vec![];
180 self.collect_dependency_level(request_id, &dep_graph, &mut level, &mut processed);
181 level_groups.push(level);
182 }
183 }
184
185 for level in level_groups {
187 if level.len() == 1 {
188 let request_id = &level[0];
190 let link = execution_context
191 .definition
192 .links
193 .iter()
194 .find(|l| l.request.id == *request_id)
195 .unwrap();
196
197 let link_clone = link.clone();
198 self.execute_request(&link_clone, execution_context).await?;
199 } else {
200 let tasks = level
202 .into_iter()
203 .map(|request_id| {
204 let link = execution_context
205 .definition
206 .links
207 .iter()
208 .find(|l| l.request.id == request_id)
209 .unwrap()
210 .clone();
211 let parallel_context = ChainExecutionContext {
213 definition: execution_context.definition.clone(),
214 templating: execution_context.templating.clone(),
215 start_time: std::time::Instant::now(),
216 config: execution_context.config.clone(),
217 };
218
219 let context = Arc::new(Mutex::new(parallel_context));
220 let engine =
221 ChainExecutionEngine::new(self.registry.clone(), self.config.clone());
222
223 tokio::spawn(async move {
224 let mut ctx = context.lock().await;
225 engine.execute_request(&link, &mut ctx).await
226 })
227 })
228 .collect::<Vec<_>>();
229
230 let results = join_all(tasks).await;
231 for result in results {
232 result
233 .map_err(|e| Error::generic(format!("Task join error: {}", e)))?
234 .map_err(|e| Error::generic(format!("Request execution error: {}", e)))?;
235 }
236 }
237 }
238
239 Ok(())
240 }
241
242 async fn execute_sequential(
244 &self,
245 execution_context: &mut ChainExecutionContext,
246 ) -> Result<()> {
247 let links = execution_context.definition.links.clone();
248 for link in &links {
249 self.execute_request(link, execution_context).await?;
250 }
251 Ok(())
252 }
253
254 async fn execute_request(
256 &self,
257 link: &ChainLink,
258 execution_context: &mut ChainExecutionContext,
259 ) -> Result<()> {
260 let request_start = std::time::Instant::now();
261
262 execution_context.templating.set_current_request(link.request.clone());
264
265 let method = Method::from_bytes(link.request.method.as_bytes()).map_err(|e| {
266 Error::generic(format!("Invalid HTTP method '{}': {}", link.request.method, e))
267 })?;
268
269 let url = self.expand_template(&link.request.url, &execution_context.templating);
270
271 let mut headers = HeaderMap::new();
273 for (key, value) in &link.request.headers {
274 let expanded_value = self.expand_template(value, &execution_context.templating);
275 let header_name = HeaderName::from_str(key)
276 .map_err(|e| Error::generic(format!("Invalid header name '{}': {}", key, e)))?;
277 let header_value = HeaderValue::from_str(&expanded_value).map_err(|e| {
278 Error::generic(format!("Invalid header value for '{}': {}", key, e))
279 })?;
280 headers.insert(header_name, header_value);
281 }
282
283 let mut request_builder = self.http_client.request(method, &url).headers(headers.clone());
285
286 if let Some(body) = &link.request.body {
288 match body {
289 crate::request_chaining::RequestBody::Json(json_value) => {
290 let expanded_body =
291 self.expand_template_in_json(json_value, &execution_context.templating);
292 request_builder = request_builder.json(&expanded_body);
293 }
294 crate::request_chaining::RequestBody::BinaryFile { path, content_type } => {
295 let templating_context =
297 TemplatingContext::with_chain(execution_context.templating.clone());
298
299 let expanded_path = expand_str_with_context(path, &templating_context);
301
302 let binary_body = crate::request_chaining::RequestBody::binary_file(
304 expanded_path,
305 content_type.clone(),
306 );
307
308 match binary_body.to_bytes().await {
310 Ok(file_bytes) => {
311 request_builder = request_builder.body(file_bytes);
312
313 if let Some(ct) = content_type {
315 let mut headers = headers.clone();
316 headers.insert(
317 "content-type",
318 ct.parse().unwrap_or_else(|_| {
319 reqwest::header::HeaderValue::from_static(
320 "application/octet-stream",
321 )
322 }),
323 );
324 request_builder = request_builder.headers(headers);
325 }
326 }
327 Err(e) => {
328 return Err(e);
329 }
330 }
331 }
332 }
333 }
334
335 if let Some(timeout_secs) = link.request.timeout_secs {
337 request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
338 }
339
340 if let Some(scripting) = &link.request.scripting {
342 if let Some(pre_script) = &scripting.pre_script {
343 let script_context = ScriptContext {
344 request: Some(link.request.clone()),
345 response: None,
346 chain_context: execution_context.templating.chain_context.variables.clone(),
347 variables: HashMap::new(),
348 env_vars: std::env::vars().collect(),
349 };
350
351 match self
352 .script_engine
353 .execute_script(pre_script, &script_context, scripting.timeout_ms)
354 .await
355 {
356 Ok(script_result) => {
357 for (key, value) in script_result.modified_variables {
359 execution_context.templating.chain_context.set_variable(key, value);
360 }
361 }
362 Err(e) => {
363 tracing::warn!(
364 "Pre-script execution failed for request '{}': {}",
365 link.request.id,
366 e
367 );
368 }
370 }
371 }
372 }
373
374 let response_result =
376 timeout(Duration::from_secs(self.config.global_timeout_secs), request_builder.send())
377 .await;
378
379 let response = match response_result {
380 Ok(Ok(resp)) => resp,
381 Ok(Err(e)) => {
382 return Err(Error::generic(format!("Request '{}' failed: {}", link.request.id, e)));
383 }
384 Err(_) => {
385 return Err(Error::generic(format!("Request '{}' timed out", link.request.id)));
386 }
387 };
388
389 let status = response.status();
390 let headers: HashMap<String, String> = response
391 .headers()
392 .iter()
393 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
394 .collect();
395
396 let body_text = response.text().await.unwrap_or_default();
397 let body_json: Option<Value> = serde_json::from_str(&body_text).ok();
398
399 let duration_ms = request_start.elapsed().as_millis() as u64;
400 let executed_at = Utc::now().to_rfc3339();
401
402 let chain_response = ChainResponse {
403 status: status.as_u16(),
404 headers,
405 body: body_json,
406 duration_ms,
407 executed_at,
408 error: None,
409 };
410
411 if let Some(expected) = &link.request.expected_status {
413 if !expected.contains(&status.as_u16()) {
414 let error_msg = format!(
415 "Request '{}' returned status {} but expected one of {:?}",
416 link.request.id,
417 status.as_u16(),
418 expected
419 );
420 return Err(Error::generic(error_msg));
421 }
422 }
423
424 if let Some(store_name) = &link.store_as {
426 execution_context
427 .templating
428 .chain_context
429 .store_response(store_name.clone(), chain_response.clone());
430 }
431
432 for (var_name, extraction_path) in &link.extract {
434 if let Some(value) = self.extract_from_response(&chain_response, extraction_path) {
435 execution_context.templating.chain_context.set_variable(var_name.clone(), value);
436 }
437 }
438
439 if let Some(scripting) = &link.request.scripting {
441 if let Some(post_script) = &scripting.post_script {
442 let script_context = ScriptContext {
443 request: Some(link.request.clone()),
444 response: Some(chain_response.clone()),
445 chain_context: execution_context.templating.chain_context.variables.clone(),
446 variables: HashMap::new(),
447 env_vars: std::env::vars().collect(),
448 };
449
450 match self
451 .script_engine
452 .execute_script(post_script, &script_context, scripting.timeout_ms)
453 .await
454 {
455 Ok(script_result) => {
456 for (key, value) in script_result.modified_variables {
458 execution_context.templating.chain_context.set_variable(key, value);
459 }
460 }
461 Err(e) => {
462 tracing::warn!(
463 "Post-script execution failed for request '{}': {}",
464 link.request.id,
465 e
466 );
467 }
469 }
470 }
471 }
472
473 execution_context
475 .templating
476 .chain_context
477 .store_response(link.request.id.clone(), chain_response);
478
479 Ok(())
480 }
481
482 fn build_dependency_graph(&self, links: &[ChainLink]) -> HashMap<String, Vec<String>> {
484 let mut graph = HashMap::new();
485
486 for link in links {
487 graph
488 .entry(link.request.id.clone())
489 .or_insert_with(Vec::new)
490 .extend(link.request.depends_on.iter().cloned());
491 }
492
493 graph
494 }
495
496 fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
498 let mut visited = HashSet::new();
499 let mut rec_stack = HashSet::new();
500 let mut result = Vec::new();
501
502 for node in graph.keys() {
503 if !visited.contains(node) {
504 self.topo_sort_util(node, graph, &mut visited, &mut rec_stack, &mut result)?;
505 }
506 }
507
508 result.reverse();
509 Ok(result)
510 }
511
512 #[allow(clippy::only_used_in_recursion)]
514 fn topo_sort_util(
515 &self,
516 node: &str,
517 graph: &HashMap<String, Vec<String>>,
518 visited: &mut HashSet<String>,
519 rec_stack: &mut HashSet<String>,
520 result: &mut Vec<String>,
521 ) -> Result<()> {
522 visited.insert(node.to_string());
523 rec_stack.insert(node.to_string());
524
525 if let Some(dependencies) = graph.get(node) {
526 for dep in dependencies {
527 if !visited.contains(dep) {
528 self.topo_sort_util(dep, graph, visited, rec_stack, result)?;
529 } else if rec_stack.contains(dep) {
530 return Err(Error::generic(format!(
531 "Circular dependency detected involving '{}'",
532 node
533 )));
534 }
535 }
536 }
537
538 rec_stack.remove(node);
539 result.push(node.to_string());
540 Ok(())
541 }
542
543 fn collect_dependency_level(
545 &self,
546 request_id: String,
547 _graph: &HashMap<String, Vec<String>>,
548 level: &mut Vec<String>,
549 processed: &mut HashSet<String>,
550 ) {
551 level.push(request_id.clone());
552 processed.insert(request_id);
553 }
554
555 fn expand_template(&self, template: &str, context: &ChainTemplatingContext) -> String {
557 let templating_context = crate::templating::TemplatingContext {
558 chain_context: Some(context.clone()),
559 env_context: None,
560 virtual_clock: None,
561 };
562 crate::templating::expand_str_with_context(template, &templating_context)
563 }
564
565 fn expand_template_in_json(&self, value: &Value, context: &ChainTemplatingContext) -> Value {
567 match value {
568 Value::String(s) => Value::String(self.expand_template(s, context)),
569 Value::Array(arr) => {
570 Value::Array(arr.iter().map(|v| self.expand_template_in_json(v, context)).collect())
571 }
572 Value::Object(map) => {
573 let mut new_map = serde_json::Map::new();
574 for (k, v) in map {
575 new_map.insert(
576 self.expand_template(k, context),
577 self.expand_template_in_json(v, context),
578 );
579 }
580 Value::Object(new_map)
581 }
582 _ => value.clone(),
583 }
584 }
585
586 fn extract_from_response(&self, response: &ChainResponse, path: &str) -> Option<Value> {
588 let parts: Vec<&str> = path.split('.').collect();
589
590 if parts.is_empty() || parts[0] != "body" {
591 return None;
592 }
593
594 let mut current = response.body.as_ref()?;
595
596 for part in &parts[1..] {
597 match current {
598 Value::Object(map) => {
599 current = map.get(*part)?;
600 }
601 Value::Array(arr) => {
602 if part.starts_with('[') && part.ends_with(']') {
603 let index_str = &part[1..part.len() - 1];
604 if let Ok(index) = index_str.parse::<usize>() {
605 current = arr.get(index)?;
606 } else {
607 return None;
608 }
609 } else {
610 return None;
611 }
612 }
613 _ => return None,
614 }
615 }
616
617 Some(current.clone())
618 }
619}
620
621#[derive(Debug, Clone)]
623pub struct ChainExecutionResult {
624 pub chain_id: String,
626 pub status: ChainExecutionStatus,
628 pub total_duration_ms: u64,
630 pub request_results: HashMap<String, ChainResponse>,
632 pub error_message: Option<String>,
634}
635
/// Overall outcome of a chain execution.
///
/// NOTE(review): the engine in this file only ever reports `Successful`. Failures are
/// returned as `Err` rather than as a result carrying `Failed`/`PartialSuccess`. Confirm
/// whether other code constructs those variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChainExecutionStatus {
    /// Every request in the chain completed.
    Successful,
    /// Presumably: some requests completed and some did not.
    PartialSuccess,
    /// The chain did not complete.
    Failed,
}
646
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Engine backed by an empty registry and default configuration.
    fn test_engine() -> ChainExecutionEngine {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        ChainExecutionEngine::new(registry, ChainConfig::default())
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let engine = test_engine();
        // A freshly built engine has not recorded any runs.
        assert!(engine.get_chain_history("unknown-chain").await.is_empty());
    }

    #[tokio::test]
    async fn test_topological_sort() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec![]);
        graph.insert("B".to_string(), vec!["A".to_string()]);
        graph.insert("C".to_string(), vec!["A".to_string()]);
        graph.insert("D".to_string(), vec!["B".to_string(), "C".to_string()]);

        let topo_order = engine.topological_sort(&graph).unwrap();

        // The sort lists dependents before the requests they depend on.
        let d_pos = topo_order.iter().position(|x| x == "D").unwrap();
        let b_pos = topo_order.iter().position(|x| x == "B").unwrap();
        let c_pos = topo_order.iter().position(|x| x == "C").unwrap();
        let a_pos = topo_order.iter().position(|x| x == "A").unwrap();

        assert!(d_pos < b_pos, "D should come before B");
        assert!(d_pos < c_pos, "D should come before C");
        assert!(b_pos < a_pos, "B should come before A");
        assert!(c_pos < a_pos, "C should come before A");
        assert_eq!(topo_order.len(), 4, "Should have all 4 nodes");
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["B".to_string()]);
        graph.insert("B".to_string(), vec!["A".to_string()]);

        let result = engine.topological_sort(&graph);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_self_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["A".to_string()]);

        assert!(engine.topological_sort(&graph).is_err());
    }

    #[tokio::test]
    async fn test_extract_from_response() {
        let engine = test_engine();
        let body = serde_json::json!({"user": {"id": 7, "tags": ["a", "b"]}});
        let response = ChainResponse {
            status: 200,
            headers: HashMap::new(),
            body: Some(body.clone()),
            duration_ms: 0,
            executed_at: String::new(),
            error: None,
        };

        assert_eq!(engine.extract_from_response(&response, "body"), Some(body));
        assert_eq!(
            engine.extract_from_response(&response, "body.user.id"),
            Some(serde_json::json!(7))
        );
        assert_eq!(
            engine.extract_from_response(&response, "body.user.tags.[1]"),
            Some(serde_json::json!("b"))
        );
        assert_eq!(engine.extract_from_response(&response, "body.user.tags.[5]"), None);
        assert_eq!(engine.extract_from_response(&response, "body.user.missing"), None);
        assert_eq!(engine.extract_from_response(&response, "headers.user"), None);
    }
}