1use crate::request_chaining::{
7 ChainConfig, ChainDefinition, ChainExecutionContext, ChainLink, ChainResponse,
8 ChainTemplatingContext, RequestChainRegistry,
9};
10use crate::request_scripting::{ScriptContext, ScriptEngine};
11use crate::templating::{expand_str_with_context, TemplatingContext};
12use crate::{Error, Result};
13use chrono::Utc;
14use futures::future::join_all;
15use reqwest::{
16 header::{HeaderMap, HeaderName, HeaderValue},
17 Client, Method,
18};
19use serde_json::Value;
20use std::collections::{HashMap, HashSet};
21use std::str::FromStr;
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use tokio::time::{timeout, Duration};
25
26#[derive(Debug, Clone)]
28pub struct ExecutionRecord {
29 pub executed_at: String,
31 pub result: ChainExecutionResult,
33}
34
35#[derive(Debug)]
37pub struct ChainExecutionEngine {
38 http_client: Client,
40 registry: Arc<RequestChainRegistry>,
42 config: ChainConfig,
44 execution_history: Arc<Mutex<HashMap<String, Vec<ExecutionRecord>>>>,
46 script_engine: ScriptEngine,
48}
49
50impl ChainExecutionEngine {
51 pub fn new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Self {
58 Self::try_new(registry, config)
59 .unwrap_or_else(|e| {
60 panic!(
61 "Failed to create HTTP client for chain execution engine: {}. \
62 This typically indicates a system configuration issue (e.g., invalid timeout value).",
63 e
64 )
65 })
66 }
67
68 pub fn try_new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Result<Self> {
72 let http_client = Client::builder()
73 .timeout(Duration::from_secs(config.global_timeout_secs))
74 .build()
75 .map_err(|e| Error::generic(format!(
76 "Failed to create HTTP client: {}. \
77 Check that the timeout value ({}) is valid.",
78 e, config.global_timeout_secs
79 )))?;
80
81 Ok(Self {
82 http_client,
83 registry,
84 config,
85 execution_history: Arc::new(Mutex::new(HashMap::new())),
86 script_engine: ScriptEngine::new(),
87 })
88 }
89
90 pub async fn execute_chain(
92 &self,
93 chain_id: &str,
94 variables: Option<serde_json::Value>,
95 ) -> Result<ChainExecutionResult> {
96 let chain = self
97 .registry
98 .get_chain(chain_id)
99 .await
100 .ok_or_else(|| Error::generic(format!("Chain '{}' not found", chain_id)))?;
101
102 let result = self.execute_chain_definition(&chain, variables).await?;
103
104 let record = ExecutionRecord {
106 executed_at: Utc::now().to_rfc3339(),
107 result: result.clone(),
108 };
109
110 let mut history = self.execution_history.lock().await;
111 history.entry(chain_id.to_string()).or_insert_with(Vec::new).push(record);
112
113 Ok(result)
114 }
115
116 pub async fn get_chain_history(&self, chain_id: &str) -> Vec<ExecutionRecord> {
118 let history = self.execution_history.lock().await;
119 history.get(chain_id).cloned().unwrap_or_default()
120 }
121
122 pub async fn execute_chain_definition(
124 &self,
125 chain_def: &ChainDefinition,
126 variables: Option<serde_json::Value>,
127 ) -> Result<ChainExecutionResult> {
128 self.registry.validate_chain(chain_def).await?;
130
131 let start_time = std::time::Instant::now();
132 let mut execution_context = ChainExecutionContext::new(chain_def.clone());
133
134 for (key, value) in &chain_def.variables {
136 execution_context
137 .templating
138 .chain_context
139 .set_variable(key.clone(), value.clone());
140 }
141
142 if let Some(serde_json::Value::Object(map)) = variables {
144 for (key, value) in map {
145 execution_context.templating.chain_context.set_variable(key, value);
146 }
147 }
148
149 if self.config.enable_parallel_execution {
150 self.execute_with_parallelism(&mut execution_context).await
151 } else {
152 self.execute_sequential(&mut execution_context).await
153 }
154 .map(|_| ChainExecutionResult {
155 chain_id: chain_def.id.clone(),
156 status: ChainExecutionStatus::Successful,
157 total_duration_ms: start_time.elapsed().as_millis() as u64,
158 request_results: execution_context.templating.chain_context.responses.clone(),
159 error_message: None,
160 })
161 }
162
163 async fn execute_with_parallelism(
165 &self,
166 execution_context: &mut ChainExecutionContext,
167 ) -> Result<()> {
168 let dep_graph = self.build_dependency_graph(&execution_context.definition.links);
169 let topo_order = self.topological_sort(&dep_graph)?;
170
171 let mut level_groups = vec![];
173 let mut processed = HashSet::new();
174
175 for request_id in topo_order {
176 if !processed.contains(&request_id) {
177 let mut level = vec![];
178 self.collect_dependency_level(request_id, &dep_graph, &mut level, &mut processed);
179 level_groups.push(level);
180 }
181 }
182
183 for level in level_groups {
185 if level.len() == 1 {
186 let request_id = &level[0];
188 let link = execution_context
189 .definition
190 .links
191 .iter()
192 .find(|l| l.request.id == *request_id)
193 .unwrap();
194
195 let link_clone = link.clone();
196 self.execute_request(&link_clone, execution_context).await?;
197 } else {
198 let tasks = level
200 .into_iter()
201 .map(|request_id| {
202 let link = execution_context
203 .definition
204 .links
205 .iter()
206 .find(|l| l.request.id == request_id)
207 .unwrap()
208 .clone();
209 let parallel_context = ChainExecutionContext {
211 definition: execution_context.definition.clone(),
212 templating: execution_context.templating.clone(),
213 start_time: std::time::Instant::now(),
214 config: execution_context.config.clone(),
215 };
216
217 let context = Arc::new(Mutex::new(parallel_context));
218 let engine =
219 ChainExecutionEngine::new(self.registry.clone(), self.config.clone());
220
221 tokio::spawn(async move {
222 let mut ctx = context.lock().await;
223 engine.execute_request(&link, &mut ctx).await
224 })
225 })
226 .collect::<Vec<_>>();
227
228 let results = join_all(tasks).await;
229 for result in results {
230 result
231 .map_err(|e| Error::generic(format!("Task join error: {}", e)))?
232 .map_err(|e| Error::generic(format!("Request execution error: {}", e)))?;
233 }
234 }
235 }
236
237 Ok(())
238 }
239
240 async fn execute_sequential(
242 &self,
243 execution_context: &mut ChainExecutionContext,
244 ) -> Result<()> {
245 let links = execution_context.definition.links.clone();
246 for link in &links {
247 self.execute_request(link, execution_context).await?;
248 }
249 Ok(())
250 }
251
252 async fn execute_request(
254 &self,
255 link: &ChainLink,
256 execution_context: &mut ChainExecutionContext,
257 ) -> Result<()> {
258 let request_start = std::time::Instant::now();
259
260 execution_context.templating.set_current_request(link.request.clone());
262
263 let method = Method::from_bytes(link.request.method.as_bytes()).map_err(|e| {
264 Error::generic(format!("Invalid HTTP method '{}': {}", link.request.method, e))
265 })?;
266
267 let url = self.expand_template(&link.request.url, &execution_context.templating);
268
269 let mut headers = HeaderMap::new();
271 for (key, value) in &link.request.headers {
272 let expanded_value = self.expand_template(value, &execution_context.templating);
273 let header_name = HeaderName::from_str(key)
274 .map_err(|e| Error::generic(format!("Invalid header name '{}': {}", key, e)))?;
275 let header_value = HeaderValue::from_str(&expanded_value).map_err(|e| {
276 Error::generic(format!("Invalid header value for '{}': {}", key, e))
277 })?;
278 headers.insert(header_name, header_value);
279 }
280
281 let mut request_builder = self.http_client.request(method, &url).headers(headers.clone());
283
284 if let Some(body) = &link.request.body {
286 match body {
287 crate::request_chaining::RequestBody::Json(json_value) => {
288 let expanded_body =
289 self.expand_template_in_json(json_value, &execution_context.templating);
290 request_builder = request_builder.json(&expanded_body);
291 }
292 crate::request_chaining::RequestBody::BinaryFile { path, content_type } => {
293 let templating_context =
295 TemplatingContext::with_chain(execution_context.templating.clone());
296
297 let expanded_path = expand_str_with_context(path, &templating_context);
299
300 let binary_body = crate::request_chaining::RequestBody::binary_file(
302 expanded_path,
303 content_type.clone(),
304 );
305
306 match binary_body.to_bytes().await {
308 Ok(file_bytes) => {
309 request_builder = request_builder.body(file_bytes);
310
311 if let Some(ct) = content_type {
313 let mut headers = headers.clone();
314 headers.insert(
315 "content-type",
316 ct.parse().unwrap_or_else(|_| {
317 reqwest::header::HeaderValue::from_static(
318 "application/octet-stream",
319 )
320 }),
321 );
322 request_builder = request_builder.headers(headers);
323 }
324 }
325 Err(e) => {
326 return Err(e);
327 }
328 }
329 }
330 }
331 }
332
333 if let Some(timeout_secs) = link.request.timeout_secs {
335 request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
336 }
337
338 if let Some(scripting) = &link.request.scripting {
340 if let Some(pre_script) = &scripting.pre_script {
341 let script_context = ScriptContext {
342 request: Some(link.request.clone()),
343 response: None,
344 chain_context: execution_context.templating.chain_context.variables.clone(),
345 variables: HashMap::new(),
346 env_vars: std::env::vars().collect(),
347 };
348
349 match self
350 .script_engine
351 .execute_script(pre_script, &script_context, scripting.timeout_ms)
352 .await
353 {
354 Ok(script_result) => {
355 for (key, value) in script_result.modified_variables {
357 execution_context.templating.chain_context.set_variable(key, value);
358 }
359 }
360 Err(e) => {
361 tracing::warn!(
362 "Pre-script execution failed for request '{}': {}",
363 link.request.id,
364 e
365 );
366 }
368 }
369 }
370 }
371
372 let response_result =
374 timeout(Duration::from_secs(self.config.global_timeout_secs), request_builder.send())
375 .await;
376
377 let response = match response_result {
378 Ok(Ok(resp)) => resp,
379 Ok(Err(e)) => {
380 return Err(Error::generic(format!("Request '{}' failed: {}", link.request.id, e)));
381 }
382 Err(_) => {
383 return Err(Error::generic(format!("Request '{}' timed out", link.request.id)));
384 }
385 };
386
387 let status = response.status();
388 let headers: HashMap<String, String> = response
389 .headers()
390 .iter()
391 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
392 .collect();
393
394 let body_text = response.text().await.unwrap_or_default();
395 let body_json: Option<Value> = serde_json::from_str(&body_text).ok();
396
397 let duration_ms = request_start.elapsed().as_millis() as u64;
398 let executed_at = Utc::now().to_rfc3339();
399
400 let chain_response = ChainResponse {
401 status: status.as_u16(),
402 headers,
403 body: body_json,
404 duration_ms,
405 executed_at,
406 error: None,
407 };
408
409 if let Some(expected) = &link.request.expected_status {
411 if !expected.contains(&status.as_u16()) {
412 let error_msg = format!(
413 "Request '{}' returned status {} but expected one of {:?}",
414 link.request.id,
415 status.as_u16(),
416 expected
417 );
418 return Err(Error::generic(error_msg));
419 }
420 }
421
422 if let Some(store_name) = &link.store_as {
424 execution_context
425 .templating
426 .chain_context
427 .store_response(store_name.clone(), chain_response.clone());
428 }
429
430 for (var_name, extraction_path) in &link.extract {
432 if let Some(value) = self.extract_from_response(&chain_response, extraction_path) {
433 execution_context.templating.chain_context.set_variable(var_name.clone(), value);
434 }
435 }
436
437 if let Some(scripting) = &link.request.scripting {
439 if let Some(post_script) = &scripting.post_script {
440 let script_context = ScriptContext {
441 request: Some(link.request.clone()),
442 response: Some(chain_response.clone()),
443 chain_context: execution_context.templating.chain_context.variables.clone(),
444 variables: HashMap::new(),
445 env_vars: std::env::vars().collect(),
446 };
447
448 match self
449 .script_engine
450 .execute_script(post_script, &script_context, scripting.timeout_ms)
451 .await
452 {
453 Ok(script_result) => {
454 for (key, value) in script_result.modified_variables {
456 execution_context.templating.chain_context.set_variable(key, value);
457 }
458 }
459 Err(e) => {
460 tracing::warn!(
461 "Post-script execution failed for request '{}': {}",
462 link.request.id,
463 e
464 );
465 }
467 }
468 }
469 }
470
471 execution_context
473 .templating
474 .chain_context
475 .store_response(link.request.id.clone(), chain_response);
476
477 Ok(())
478 }
479
480 fn build_dependency_graph(&self, links: &[ChainLink]) -> HashMap<String, Vec<String>> {
482 let mut graph = HashMap::new();
483
484 for link in links {
485 graph
486 .entry(link.request.id.clone())
487 .or_insert_with(Vec::new)
488 .extend(link.request.depends_on.iter().cloned());
489 }
490
491 graph
492 }
493
494 fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
496 let mut visited = HashSet::new();
497 let mut rec_stack = HashSet::new();
498 let mut result = Vec::new();
499
500 for node in graph.keys() {
501 if !visited.contains(node) {
502 self.topo_sort_util(node, graph, &mut visited, &mut rec_stack, &mut result)?;
503 }
504 }
505
506 result.reverse();
507 Ok(result)
508 }
509
510 #[allow(clippy::only_used_in_recursion)]
512 fn topo_sort_util(
513 &self,
514 node: &str,
515 graph: &HashMap<String, Vec<String>>,
516 visited: &mut HashSet<String>,
517 rec_stack: &mut HashSet<String>,
518 result: &mut Vec<String>,
519 ) -> Result<()> {
520 visited.insert(node.to_string());
521 rec_stack.insert(node.to_string());
522
523 if let Some(dependencies) = graph.get(node) {
524 for dep in dependencies {
525 if !visited.contains(dep) {
526 self.topo_sort_util(dep, graph, visited, rec_stack, result)?;
527 } else if rec_stack.contains(dep) {
528 return Err(Error::generic(format!(
529 "Circular dependency detected involving '{}'",
530 node
531 )));
532 }
533 }
534 }
535
536 rec_stack.remove(node);
537 result.push(node.to_string());
538 Ok(())
539 }
540
541 fn collect_dependency_level(
543 &self,
544 request_id: String,
545 _graph: &HashMap<String, Vec<String>>,
546 level: &mut Vec<String>,
547 processed: &mut HashSet<String>,
548 ) {
549 level.push(request_id.clone());
550 processed.insert(request_id);
551 }
552
553 fn expand_template(&self, template: &str, context: &ChainTemplatingContext) -> String {
555 let templating_context = crate::templating::TemplatingContext {
556 chain_context: Some(context.clone()),
557 env_context: None,
558 virtual_clock: None,
559 };
560 crate::templating::expand_str_with_context(template, &templating_context)
561 }
562
563 fn expand_template_in_json(&self, value: &Value, context: &ChainTemplatingContext) -> Value {
565 match value {
566 Value::String(s) => Value::String(self.expand_template(s, context)),
567 Value::Array(arr) => {
568 Value::Array(arr.iter().map(|v| self.expand_template_in_json(v, context)).collect())
569 }
570 Value::Object(map) => {
571 let mut new_map = serde_json::Map::new();
572 for (k, v) in map {
573 new_map.insert(
574 self.expand_template(k, context),
575 self.expand_template_in_json(v, context),
576 );
577 }
578 Value::Object(new_map)
579 }
580 _ => value.clone(),
581 }
582 }
583
584 fn extract_from_response(&self, response: &ChainResponse, path: &str) -> Option<Value> {
586 let parts: Vec<&str> = path.split('.').collect();
587
588 if parts.is_empty() || parts[0] != "body" {
589 return None;
590 }
591
592 let mut current = response.body.as_ref()?;
593
594 for part in &parts[1..] {
595 match current {
596 Value::Object(map) => {
597 current = map.get(*part)?;
598 }
599 Value::Array(arr) => {
600 if part.starts_with('[') && part.ends_with(']') {
601 let index_str = &part[1..part.len() - 1];
602 if let Ok(index) = index_str.parse::<usize>() {
603 current = arr.get(index)?;
604 } else {
605 return None;
606 }
607 } else {
608 return None;
609 }
610 }
611 _ => return None,
612 }
613 }
614
615 Some(current.clone())
616 }
617}
618
619#[derive(Debug, Clone)]
621pub struct ChainExecutionResult {
622 pub chain_id: String,
624 pub status: ChainExecutionStatus,
626 pub total_duration_ms: u64,
628 pub request_results: HashMap<String, ChainResponse>,
630 pub error_message: Option<String>,
632}
633
/// Overall outcome of a chain run.
///
/// All variants are fieldless, so equality is total and `Eq` is derived alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChainExecutionStatus {
    /// The run completed successfully.
    Successful,
    /// Presumably: some, but not all, requests succeeded — confirm with producers of this value.
    PartialSuccess,
    /// The run failed.
    Failed,
}
644
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds an engine over an empty registry with default configuration.
    fn test_engine() -> ChainExecutionEngine {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        ChainExecutionEngine::new(registry, ChainConfig::default())
    }

    /// Builds a 200 response carrying `body` as its JSON payload.
    fn response_with_body(body: Value) -> ChainResponse {
        ChainResponse {
            status: 200,
            headers: HashMap::new(),
            body: Some(body),
            duration_ms: 0,
            executed_at: String::new(),
            error: None,
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        let _engine = ChainExecutionEngine::new(registry.clone(), ChainConfig::default());
        assert!(ChainExecutionEngine::try_new(registry, ChainConfig::default()).is_ok());
    }

    #[tokio::test]
    async fn test_topological_sort() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec![]);
        graph.insert("B".to_string(), vec!["A".to_string()]);
        graph.insert("C".to_string(), vec!["A".to_string()]);
        graph.insert("D".to_string(), vec!["B".to_string(), "C".to_string()]);

        let topo_order = engine.topological_sort(&graph).unwrap();

        // `topological_sort` lists dependents before their dependencies.
        let d_pos = topo_order.iter().position(|x| x == "D").unwrap();
        let b_pos = topo_order.iter().position(|x| x == "B").unwrap();
        let c_pos = topo_order.iter().position(|x| x == "C").unwrap();
        let a_pos = topo_order.iter().position(|x| x == "A").unwrap();

        assert!(d_pos < b_pos, "D should come before B");
        assert!(d_pos < c_pos, "D should come before C");
        assert!(b_pos < a_pos, "B should come before A");
        assert!(c_pos < a_pos, "C should come before A");
        assert_eq!(topo_order.len(), 4, "Should have all 4 nodes");
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["B".to_string()]);
        graph.insert("B".to_string(), vec!["A".to_string()]);

        let result = engine.topological_sort(&graph);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_extract_from_response() {
        let engine = test_engine();
        let response = response_with_body(serde_json::json!({
            "user": { "id": 7 },
            "items": [{ "name": "a" }, { "name": "b" }]
        }));

        assert_eq!(
            engine.extract_from_response(&response, "body.user.id"),
            Some(serde_json::json!(7))
        );
        assert_eq!(
            engine.extract_from_response(&response, "body.items.[1].name"),
            Some(serde_json::json!("b"))
        );
        assert_eq!(engine.extract_from_response(&response, "body"), response.body.clone());
        assert_eq!(engine.extract_from_response(&response, "body.missing"), None);
        assert_eq!(engine.extract_from_response(&response, "headers.user"), None);
    }

    #[tokio::test]
    async fn test_history_is_empty_for_unknown_chain() {
        let engine = test_engine();
        assert!(engine.get_chain_history("no-such-chain").await.is_empty());
    }
}