1use crate::request_chaining::{
7 ChainConfig, ChainDefinition, ChainExecutionContext, ChainLink, ChainResponse,
8 ChainTemplatingContext, RequestChainRegistry,
9};
10use crate::templating::{expand_str_with_context, TemplatingContext};
11use crate::{Error, Result};
12use chrono::Utc;
13use futures::future::join_all;
14use reqwest::{
15 header::{HeaderMap, HeaderName, HeaderValue},
16 Client, Method,
17};
18use serde_json::Value;
19use std::collections::{HashMap, HashSet};
20use std::str::FromStr;
21use std::sync::Arc;
22use tokio::sync::Mutex;
23use tokio::time::{timeout, Duration};
24
25#[derive(Debug, Clone)]
27pub struct ExecutionRecord {
28 pub executed_at: String,
29 pub result: ChainExecutionResult,
30}
31
32#[derive(Debug)]
34pub struct ChainExecutionEngine {
35 http_client: Client,
37 registry: Arc<RequestChainRegistry>,
39 config: ChainConfig,
41 execution_history: Arc<Mutex<HashMap<String, Vec<ExecutionRecord>>>>,
43}
44
45impl ChainExecutionEngine {
46 pub fn new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Self {
48 let http_client = Client::builder()
49 .timeout(Duration::from_secs(config.global_timeout_secs))
50 .build()
51 .expect("Failed to create HTTP client");
52
53 Self {
54 http_client,
55 registry,
56 config,
57 execution_history: Arc::new(Mutex::new(HashMap::new())),
58 }
59 }
60
61 pub async fn execute_chain(
63 &self,
64 chain_id: &str,
65 variables: Option<serde_json::Value>,
66 ) -> Result<ChainExecutionResult> {
67 let chain = self
68 .registry
69 .get_chain(chain_id)
70 .await
71 .ok_or_else(|| Error::generic(format!("Chain '{}' not found", chain_id)))?;
72
73 let result = self.execute_chain_definition(&chain, variables).await?;
74
75 let record = ExecutionRecord {
77 executed_at: Utc::now().to_rfc3339(),
78 result: result.clone(),
79 };
80
81 let mut history = self.execution_history.lock().await;
82 history.entry(chain_id.to_string()).or_insert_with(Vec::new).push(record);
83
84 Ok(result)
85 }
86
87 pub async fn get_chain_history(&self, chain_id: &str) -> Vec<ExecutionRecord> {
89 let history = self.execution_history.lock().await;
90 history.get(chain_id).cloned().unwrap_or_default()
91 }
92
93 pub async fn execute_chain_definition(
95 &self,
96 chain_def: &ChainDefinition,
97 variables: Option<serde_json::Value>,
98 ) -> Result<ChainExecutionResult> {
99 self.registry.validate_chain(chain_def).await?;
101
102 let start_time = std::time::Instant::now();
103 let mut execution_context = ChainExecutionContext::new(chain_def.clone());
104
105 for (key, value) in &chain_def.variables {
107 execution_context
108 .templating
109 .chain_context
110 .set_variable(key.clone(), value.clone());
111 }
112
113 if let Some(serde_json::Value::Object(map)) = variables {
115 for (key, value) in map {
116 execution_context.templating.chain_context.set_variable(key, value);
117 }
118 }
119
120 if self.config.enable_parallel_execution {
121 self.execute_with_parallelism(&mut execution_context).await
122 } else {
123 self.execute_sequential(&mut execution_context).await
124 }
125 .map(|_| ChainExecutionResult {
126 chain_id: chain_def.id.clone(),
127 status: ChainExecutionStatus::Successful,
128 total_duration_ms: start_time.elapsed().as_millis() as u64,
129 request_results: execution_context.templating.chain_context.responses.clone(),
130 error_message: None,
131 })
132 }
133
134 async fn execute_with_parallelism(
136 &self,
137 execution_context: &mut ChainExecutionContext,
138 ) -> Result<()> {
139 let dep_graph = self.build_dependency_graph(&execution_context.definition.links);
140 let topo_order = self.topological_sort(&dep_graph)?;
141
142 let mut level_groups = vec![];
144 let mut processed = HashSet::new();
145
146 for request_id in topo_order {
147 if !processed.contains(&request_id) {
148 let mut level = vec![];
149 self.collect_dependency_level(request_id, &dep_graph, &mut level, &mut processed);
150 level_groups.push(level);
151 }
152 }
153
154 for level in level_groups {
156 if level.len() == 1 {
157 let request_id = &level[0];
159 let link = execution_context
160 .definition
161 .links
162 .iter()
163 .find(|l| l.request.id == *request_id)
164 .unwrap();
165
166 let link_clone = link.clone();
167 self.execute_request(&link_clone, execution_context).await?;
168 } else {
169 let tasks = level
171 .into_iter()
172 .map(|request_id| {
173 let link = execution_context
174 .definition
175 .links
176 .iter()
177 .find(|l| l.request.id == request_id)
178 .unwrap()
179 .clone();
180 let parallel_context = ChainExecutionContext {
182 definition: execution_context.definition.clone(),
183 templating: execution_context.templating.clone(),
184 start_time: std::time::Instant::now(),
185 config: execution_context.config.clone(),
186 };
187
188 let context = Arc::new(Mutex::new(parallel_context));
189 let engine =
190 ChainExecutionEngine::new(self.registry.clone(), self.config.clone());
191
192 tokio::spawn(async move {
193 let mut ctx = context.lock().await;
194 engine.execute_request(&link, &mut ctx).await
195 })
196 })
197 .collect::<Vec<_>>();
198
199 let results = join_all(tasks).await;
200 for result in results {
201 result
202 .map_err(|e| Error::generic(format!("Task join error: {}", e)))?
203 .map_err(|e| Error::generic(format!("Request execution error: {}", e)))?;
204 }
205 }
206 }
207
208 Ok(())
209 }
210
211 async fn execute_sequential(
213 &self,
214 execution_context: &mut ChainExecutionContext,
215 ) -> Result<()> {
216 let links = execution_context.definition.links.clone();
217 for link in &links {
218 self.execute_request(link, execution_context).await?;
219 }
220 Ok(())
221 }
222
223 async fn execute_request(
225 &self,
226 link: &ChainLink,
227 execution_context: &mut ChainExecutionContext,
228 ) -> Result<()> {
229 let request_start = std::time::Instant::now();
230
231 execution_context.templating.set_current_request(link.request.clone());
233
234 let method = Method::from_bytes(link.request.method.as_bytes()).map_err(|e| {
235 Error::generic(format!("Invalid HTTP method '{}': {}", link.request.method, e))
236 })?;
237
238 let url = self.expand_template(&link.request.url, &execution_context.templating);
239
240 let mut headers = HeaderMap::new();
242 for (key, value) in &link.request.headers {
243 let expanded_value = self.expand_template(value, &execution_context.templating);
244 let header_name = HeaderName::from_str(key)
245 .map_err(|e| Error::generic(format!("Invalid header name '{}': {}", key, e)))?;
246 let header_value = HeaderValue::from_str(&expanded_value).map_err(|e| {
247 Error::generic(format!("Invalid header value for '{}': {}", key, e))
248 })?;
249 headers.insert(header_name, header_value);
250 }
251
252 let mut request_builder = self.http_client.request(method, &url).headers(headers.clone());
254
255 if let Some(body) = &link.request.body {
257 match body {
258 crate::request_chaining::RequestBody::Json(json_value) => {
259 let expanded_body =
260 self.expand_template_in_json(json_value, &execution_context.templating);
261 request_builder = request_builder.json(&expanded_body);
262 }
263 crate::request_chaining::RequestBody::BinaryFile { path, content_type } => {
264 let templating_context =
266 TemplatingContext::with_chain(execution_context.templating.clone());
267
268 let expanded_path = expand_str_with_context(path, &templating_context);
270
271 let binary_body = crate::request_chaining::RequestBody::binary_file(
273 expanded_path,
274 content_type.clone(),
275 );
276
277 match binary_body.to_bytes().await {
279 Ok(file_bytes) => {
280 request_builder = request_builder.body(file_bytes);
281
282 if let Some(ct) = content_type {
284 let mut headers = headers.clone();
285 headers.insert(
286 "content-type",
287 ct.parse().unwrap_or_else(|_| {
288 reqwest::header::HeaderValue::from_static(
289 "application/octet-stream",
290 )
291 }),
292 );
293 request_builder = request_builder.headers(headers);
294 }
295 }
296 Err(e) => {
297 return Err(e);
298 }
299 }
300 }
301 }
302 }
303
304 if let Some(timeout_secs) = link.request.timeout_secs {
306 request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
307 }
308
309 let response_result =
311 timeout(Duration::from_secs(self.config.global_timeout_secs), request_builder.send())
312 .await;
313
314 let response = match response_result {
315 Ok(Ok(resp)) => resp,
316 Ok(Err(e)) => {
317 return Err(Error::generic(format!("Request '{}' failed: {}", link.request.id, e)));
318 }
319 Err(_) => {
320 return Err(Error::generic(format!("Request '{}' timed out", link.request.id)));
321 }
322 };
323
324 let status = response.status();
325 let headers: HashMap<String, String> = response
326 .headers()
327 .iter()
328 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
329 .collect();
330
331 let body_text = response.text().await.unwrap_or_default();
332 let body_json: Option<Value> = serde_json::from_str(&body_text).ok();
333
334 let duration_ms = request_start.elapsed().as_millis() as u64;
335 let executed_at = Utc::now().to_rfc3339();
336
337 let chain_response = ChainResponse {
338 status: status.as_u16(),
339 headers,
340 body: body_json,
341 duration_ms,
342 executed_at,
343 error: None,
344 };
345
346 if let Some(expected) = &link.request.expected_status {
348 if !expected.contains(&status.as_u16()) {
349 let error_msg = format!(
350 "Request '{}' returned status {} but expected one of {:?}",
351 link.request.id,
352 status.as_u16(),
353 expected
354 );
355 return Err(Error::generic(error_msg));
356 }
357 }
358
359 if let Some(store_name) = &link.store_as {
361 execution_context
362 .templating
363 .chain_context
364 .store_response(store_name.clone(), chain_response.clone());
365 }
366
367 for (var_name, extraction_path) in &link.extract {
369 if let Some(value) = self.extract_from_response(&chain_response, extraction_path) {
370 execution_context.templating.chain_context.set_variable(var_name.clone(), value);
371 }
372 }
373
374 execution_context
376 .templating
377 .chain_context
378 .store_response(link.request.id.clone(), chain_response);
379
380 Ok(())
381 }
382
383 fn build_dependency_graph(&self, links: &[ChainLink]) -> HashMap<String, Vec<String>> {
385 let mut graph = HashMap::new();
386
387 for link in links {
388 graph
389 .entry(link.request.id.clone())
390 .or_insert_with(Vec::new)
391 .extend(link.request.depends_on.iter().cloned());
392 }
393
394 graph
395 }
396
397 fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
399 let mut visited = HashSet::new();
400 let mut rec_stack = HashSet::new();
401 let mut result = Vec::new();
402
403 for node in graph.keys() {
404 if !visited.contains(node) {
405 self.topo_sort_util(node, graph, &mut visited, &mut rec_stack, &mut result)?;
406 }
407 }
408
409 result.reverse();
410 Ok(result)
411 }
412
413 #[allow(clippy::only_used_in_recursion)]
415 fn topo_sort_util(
416 &self,
417 node: &str,
418 graph: &HashMap<String, Vec<String>>,
419 visited: &mut HashSet<String>,
420 rec_stack: &mut HashSet<String>,
421 result: &mut Vec<String>,
422 ) -> Result<()> {
423 visited.insert(node.to_string());
424 rec_stack.insert(node.to_string());
425
426 if let Some(dependencies) = graph.get(node) {
427 for dep in dependencies {
428 if !visited.contains(dep) {
429 self.topo_sort_util(dep, graph, visited, rec_stack, result)?;
430 } else if rec_stack.contains(dep) {
431 return Err(Error::generic(format!(
432 "Circular dependency detected involving '{}'",
433 node
434 )));
435 }
436 }
437 }
438
439 rec_stack.remove(node);
440 result.push(node.to_string());
441 Ok(())
442 }
443
444 fn collect_dependency_level(
446 &self,
447 request_id: String,
448 _graph: &HashMap<String, Vec<String>>,
449 level: &mut Vec<String>,
450 processed: &mut HashSet<String>,
451 ) {
452 level.push(request_id.clone());
453 processed.insert(request_id);
454 }
455
456 fn expand_template(&self, template: &str, context: &ChainTemplatingContext) -> String {
458 let templating_context = crate::templating::TemplatingContext {
459 chain_context: Some(context.clone()),
460 env_context: None,
461 virtual_clock: None,
462 };
463 crate::templating::expand_str_with_context(template, &templating_context)
464 }
465
466 fn expand_template_in_json(&self, value: &Value, context: &ChainTemplatingContext) -> Value {
468 match value {
469 Value::String(s) => Value::String(self.expand_template(s, context)),
470 Value::Array(arr) => {
471 Value::Array(arr.iter().map(|v| self.expand_template_in_json(v, context)).collect())
472 }
473 Value::Object(map) => {
474 let mut new_map = serde_json::Map::new();
475 for (k, v) in map {
476 new_map.insert(
477 self.expand_template(k, context),
478 self.expand_template_in_json(v, context),
479 );
480 }
481 Value::Object(new_map)
482 }
483 _ => value.clone(),
484 }
485 }
486
487 fn extract_from_response(&self, response: &ChainResponse, path: &str) -> Option<Value> {
489 let parts: Vec<&str> = path.split('.').collect();
490
491 if parts.is_empty() || parts[0] != "body" {
492 return None;
493 }
494
495 let mut current = response.body.as_ref()?;
496
497 for part in &parts[1..] {
498 match current {
499 Value::Object(map) => {
500 current = map.get(*part)?;
501 }
502 Value::Array(arr) => {
503 if part.starts_with('[') && part.ends_with(']') {
504 let index_str = &part[1..part.len() - 1];
505 if let Ok(index) = index_str.parse::<usize>() {
506 current = arr.get(index)?;
507 } else {
508 return None;
509 }
510 } else {
511 return None;
512 }
513 }
514 _ => return None,
515 }
516 }
517
518 Some(current.clone())
519 }
520}
521
522#[derive(Debug, Clone)]
524pub struct ChainExecutionResult {
525 pub chain_id: String,
526 pub status: ChainExecutionStatus,
527 pub total_duration_ms: u64,
528 pub request_results: HashMap<String, ChainResponse>,
529 pub error_message: Option<String>,
530}
531
/// Overall outcome of a chain execution.
#[derive(Clone, Debug, PartialEq)]
pub enum ChainExecutionStatus {
    /// Every request in the chain completed.
    Successful,
    /// Some requests completed and others did not.
    PartialSuccess,
    /// The chain did not complete.
    Failed,
}
539
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Engine backed by an empty registry with default settings.
    fn test_engine() -> ChainExecutionEngine {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        ChainExecutionEngine::new(registry, ChainConfig::default())
    }

    /// A 200 response carrying `body`, for the extraction tests.
    fn response_with_body(body: Value) -> ChainResponse {
        ChainResponse {
            status: 200,
            headers: HashMap::new(),
            body: Some(body),
            duration_ms: 0,
            executed_at: String::new(),
            error: None,
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let _engine = test_engine();
    }

    #[tokio::test]
    async fn test_topological_sort() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec![]);
        graph.insert("B".to_string(), vec!["A".to_string()]);
        graph.insert("C".to_string(), vec!["A".to_string()]);
        graph.insert("D".to_string(), vec!["B".to_string(), "C".to_string()]);

        let topo_order = engine.topological_sort(&graph).unwrap();

        // `topological_sort` lists dependents before the requests they depend on.
        let d_pos = topo_order.iter().position(|x| x == "D").unwrap();
        let b_pos = topo_order.iter().position(|x| x == "B").unwrap();
        let c_pos = topo_order.iter().position(|x| x == "C").unwrap();
        let a_pos = topo_order.iter().position(|x| x == "A").unwrap();

        assert!(d_pos < b_pos, "D should come before B");
        assert!(d_pos < c_pos, "D should come before C");
        assert!(b_pos < a_pos, "B should come before A");
        assert!(c_pos < a_pos, "C should come before A");
        assert_eq!(topo_order.len(), 4, "Should have all 4 nodes");
    }

    #[tokio::test]
    async fn test_topological_sort_empty_graph() {
        let engine = test_engine();
        let topo_order = engine.topological_sort(&HashMap::new()).unwrap();
        assert!(topo_order.is_empty());
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["B".to_string()]);
        graph.insert("B".to_string(), vec!["A".to_string()]);

        let result = engine.topological_sort(&graph);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_self_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["A".to_string()]);

        assert!(engine.topological_sort(&graph).is_err());
    }

    #[tokio::test]
    async fn test_extract_from_response() {
        let engine = test_engine();
        let response = response_with_body(serde_json::json!({
            "user": { "id": 42, "tags": ["a", "b"] }
        }));

        assert_eq!(
            engine.extract_from_response(&response, "body.user.id"),
            Some(serde_json::json!(42))
        );
        assert_eq!(
            engine.extract_from_response(&response, "body.user.tags.[1]"),
            Some(serde_json::json!("b"))
        );
        assert_eq!(
            engine.extract_from_response(&response, "body"),
            response.body.clone()
        );
        // Missing members, non-`body` roots, and unbracketed array indices do not resolve.
        assert_eq!(engine.extract_from_response(&response, "body.user.missing"), None);
        assert_eq!(engine.extract_from_response(&response, "headers.x"), None);
        assert_eq!(engine.extract_from_response(&response, "body.user.tags.1"), None);
        assert_eq!(engine.extract_from_response(&response, "body.user.tags.[9]"), None);
    }
}
594}