mockforge_core/
chain_execution.rs

1//! Chain execution engine for request chaining
2//!
3//! This module provides the execution engine that manages chain execution with
4//! dependency resolution, parallel execution when possible, and proper error handling.
5
6use crate::request_chaining::{
7    ChainConfig, ChainDefinition, ChainExecutionContext, ChainLink, ChainResponse,
8    ChainTemplatingContext, RequestChainRegistry,
9};
10use crate::request_scripting::{ScriptContext, ScriptEngine};
11use crate::templating::{expand_str_with_context, TemplatingContext};
12use crate::{Error, Result};
13use chrono::Utc;
14use futures::future::join_all;
15use reqwest::{
16    header::{HeaderMap, HeaderName, HeaderValue},
17    Client, Method,
18};
19use serde_json::Value;
20use std::collections::{HashMap, HashSet};
21use std::str::FromStr;
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use tokio::time::{timeout, Duration};
25
/// Entry in a chain's execution history: when the chain ran and what it produced.
///
/// `ChainExecutionEngine::execute_chain` appends one of these per call, and only
/// after the chain executed successfully (a failing execution returns early and is
/// not recorded).
#[derive(Debug, Clone)]
pub struct ExecutionRecord {
    /// ISO 8601 / RFC 3339 timestamp (UTC) of the execution. `execute_chain` fills it
    /// with `Utc::now().to_rfc3339()` once execution has completed.
    pub executed_at: String,
    /// Result of the chain execution
    pub result: ChainExecutionResult,
}
34
/// Engine for executing request chains.
///
/// Chains are looked up in (and validated against) a [`RequestChainRegistry`] and then
/// executed either in declared order or, when `ChainConfig::enable_parallel_execution`
/// is set, via a dependency-graph based path. Successful executions started through
/// `execute_chain` are recorded in an in-memory history.
#[derive(Debug)]
pub struct ChainExecutionEngine {
    /// HTTP client for making requests; built in `try_new` with
    /// `ChainConfig::global_timeout_secs` as its timeout.
    http_client: Client,
    /// Chain registry, used to look chains up by ID and to validate them.
    registry: Arc<RequestChainRegistry>,
    /// Global configuration (global timeout, parallel-execution toggle, ...).
    config: ChainConfig,
    /// Execution history storage (chain_id -> Vec<ExecutionRecord>), guarded by an
    /// async (tokio) mutex. Appended to by `execute_chain`; nothing in this file
    /// removes entries.
    execution_history: Arc<Mutex<HashMap<String, Vec<ExecutionRecord>>>>,
    /// JavaScript scripting engine for pre/post request scripts
    script_engine: ScriptEngine,
}
49
50impl ChainExecutionEngine {
51    /// Create a new chain execution engine
52    /// 
53    /// # Panics
54    /// 
55    /// This method will panic if the HTTP client cannot be created, which typically
56    /// indicates a system configuration issue. For better error handling, use `try_new()`.
57    pub fn new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Self {
58        Self::try_new(registry, config)
59            .unwrap_or_else(|e| {
60                panic!(
61                    "Failed to create HTTP client for chain execution engine: {}. \
62                    This typically indicates a system configuration issue (e.g., invalid timeout value).",
63                    e
64                )
65            })
66    }
67
68    /// Try to create a new chain execution engine
69    /// 
70    /// Returns an error if the HTTP client cannot be created.
71    pub fn try_new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Result<Self> {
72        let http_client = Client::builder()
73            .timeout(Duration::from_secs(config.global_timeout_secs))
74            .build()
75            .map_err(|e| Error::generic(format!(
76                "Failed to create HTTP client: {}. \
77                Check that the timeout value ({}) is valid.",
78                e, config.global_timeout_secs
79            )))?;
80
81        Ok(Self {
82            http_client,
83            registry,
84            config,
85            execution_history: Arc::new(Mutex::new(HashMap::new())),
86            script_engine: ScriptEngine::new(),
87        })
88    }
89
90    /// Execute a chain by ID
91    pub async fn execute_chain(
92        &self,
93        chain_id: &str,
94        variables: Option<serde_json::Value>,
95    ) -> Result<ChainExecutionResult> {
96        let chain = self
97            .registry
98            .get_chain(chain_id)
99            .await
100            .ok_or_else(|| Error::generic(format!("Chain '{}' not found", chain_id)))?;
101
102        let result = self.execute_chain_definition(&chain, variables).await?;
103
104        // Store execution in history
105        let record = ExecutionRecord {
106            executed_at: Utc::now().to_rfc3339(),
107            result: result.clone(),
108        };
109
110        let mut history = self.execution_history.lock().await;
111        history.entry(chain_id.to_string()).or_insert_with(Vec::new).push(record);
112
113        Ok(result)
114    }
115
116    /// Get execution history for a chain
117    pub async fn get_chain_history(&self, chain_id: &str) -> Vec<ExecutionRecord> {
118        let history = self.execution_history.lock().await;
119        history.get(chain_id).cloned().unwrap_or_default()
120    }
121
122    /// Execute a chain definition
123    pub async fn execute_chain_definition(
124        &self,
125        chain_def: &ChainDefinition,
126        variables: Option<serde_json::Value>,
127    ) -> Result<ChainExecutionResult> {
128        // First validate the chain
129        self.registry.validate_chain(chain_def).await?;
130
131        let start_time = std::time::Instant::now();
132        let mut execution_context = ChainExecutionContext::new(chain_def.clone());
133
134        // Initialize context with chain variables
135        for (key, value) in &chain_def.variables {
136            execution_context
137                .templating
138                .chain_context
139                .set_variable(key.clone(), value.clone());
140        }
141
142        // Merge custom variables from request
143        if let Some(serde_json::Value::Object(map)) = variables {
144            for (key, value) in map {
145                execution_context.templating.chain_context.set_variable(key, value);
146            }
147        }
148
149        if self.config.enable_parallel_execution {
150            self.execute_with_parallelism(&mut execution_context).await
151        } else {
152            self.execute_sequential(&mut execution_context).await
153        }
154        .map(|_| ChainExecutionResult {
155            chain_id: chain_def.id.clone(),
156            status: ChainExecutionStatus::Successful,
157            total_duration_ms: start_time.elapsed().as_millis() as u64,
158            request_results: execution_context.templating.chain_context.responses.clone(),
159            error_message: None,
160        })
161    }
162
163    /// Execute chain using topological sorting for parallelism
164    async fn execute_with_parallelism(
165        &self,
166        execution_context: &mut ChainExecutionContext,
167    ) -> Result<()> {
168        let dep_graph = self.build_dependency_graph(&execution_context.definition.links);
169        let topo_order = self.topological_sort(&dep_graph)?;
170
171        // Group requests by dependency level
172        let mut level_groups = vec![];
173        let mut processed = HashSet::new();
174
175        for request_id in topo_order {
176            if !processed.contains(&request_id) {
177                let mut level = vec![];
178                self.collect_dependency_level(request_id, &dep_graph, &mut level, &mut processed);
179                level_groups.push(level);
180            }
181        }
182
183        // Execute levels in parallel
184        for level in level_groups {
185            if level.len() == 1 {
186                // Single request, execute directly
187                let request_id = &level[0];
188                let link = execution_context
189                    .definition
190                    .links
191                    .iter()
192                    .find(|l| l.request.id == *request_id)
193                    .unwrap();
194
195                let link_clone = link.clone();
196                self.execute_request(&link_clone, execution_context).await?;
197            } else {
198                // Execute level in parallel
199                let tasks = level
200                    .into_iter()
201                    .map(|request_id| {
202                        let link = execution_context
203                            .definition
204                            .links
205                            .iter()
206                            .find(|l| l.request.id == request_id)
207                            .unwrap()
208                            .clone();
209                        // Create a new context for parallel execution
210                        let parallel_context = ChainExecutionContext {
211                            definition: execution_context.definition.clone(),
212                            templating: execution_context.templating.clone(),
213                            start_time: std::time::Instant::now(),
214                            config: execution_context.config.clone(),
215                        };
216
217                        let context = Arc::new(Mutex::new(parallel_context));
218                        let engine =
219                            ChainExecutionEngine::new(self.registry.clone(), self.config.clone());
220
221                        tokio::spawn(async move {
222                            let mut ctx = context.lock().await;
223                            engine.execute_request(&link, &mut ctx).await
224                        })
225                    })
226                    .collect::<Vec<_>>();
227
228                let results = join_all(tasks).await;
229                for result in results {
230                    result
231                        .map_err(|e| Error::generic(format!("Task join error: {}", e)))?
232                        .map_err(|e| Error::generic(format!("Request execution error: {}", e)))?;
233                }
234            }
235        }
236
237        Ok(())
238    }
239
240    /// Execute requests sequentially
241    async fn execute_sequential(
242        &self,
243        execution_context: &mut ChainExecutionContext,
244    ) -> Result<()> {
245        let links = execution_context.definition.links.clone();
246        for link in &links {
247            self.execute_request(link, execution_context).await?;
248        }
249        Ok(())
250    }
251
252    /// Execute a single request in the chain
253    async fn execute_request(
254        &self,
255        link: &ChainLink,
256        execution_context: &mut ChainExecutionContext,
257    ) -> Result<()> {
258        let request_start = std::time::Instant::now();
259
260        // Prepare the request with templating
261        execution_context.templating.set_current_request(link.request.clone());
262
263        let method = Method::from_bytes(link.request.method.as_bytes()).map_err(|e| {
264            Error::generic(format!("Invalid HTTP method '{}': {}", link.request.method, e))
265        })?;
266
267        let url = self.expand_template(&link.request.url, &execution_context.templating);
268
269        // Prepare headers
270        let mut headers = HeaderMap::new();
271        for (key, value) in &link.request.headers {
272            let expanded_value = self.expand_template(value, &execution_context.templating);
273            let header_name = HeaderName::from_str(key)
274                .map_err(|e| Error::generic(format!("Invalid header name '{}': {}", key, e)))?;
275            let header_value = HeaderValue::from_str(&expanded_value).map_err(|e| {
276                Error::generic(format!("Invalid header value for '{}': {}", key, e))
277            })?;
278            headers.insert(header_name, header_value);
279        }
280
281        // Prepare request builder
282        let mut request_builder = self.http_client.request(method, &url).headers(headers.clone());
283
284        // Add body if present
285        if let Some(body) = &link.request.body {
286            match body {
287                crate::request_chaining::RequestBody::Json(json_value) => {
288                    let expanded_body =
289                        self.expand_template_in_json(json_value, &execution_context.templating);
290                    request_builder = request_builder.json(&expanded_body);
291                }
292                crate::request_chaining::RequestBody::BinaryFile { path, content_type } => {
293                    // Create templating context for path expansion
294                    let templating_context =
295                        TemplatingContext::with_chain(execution_context.templating.clone());
296
297                    // Expand templates in the file path
298                    let expanded_path = expand_str_with_context(path, &templating_context);
299
300                    // Create a new body with expanded path
301                    let binary_body = crate::request_chaining::RequestBody::binary_file(
302                        expanded_path,
303                        content_type.clone(),
304                    );
305
306                    // Read the binary file
307                    match binary_body.to_bytes().await {
308                        Ok(file_bytes) => {
309                            request_builder = request_builder.body(file_bytes);
310
311                            // Set content type if specified
312                            if let Some(ct) = content_type {
313                                let mut headers = headers.clone();
314                                headers.insert(
315                                    "content-type",
316                                    ct.parse().unwrap_or_else(|_| {
317                                        reqwest::header::HeaderValue::from_static(
318                                            "application/octet-stream",
319                                        )
320                                    }),
321                                );
322                                request_builder = request_builder.headers(headers);
323                            }
324                        }
325                        Err(e) => {
326                            return Err(e);
327                        }
328                    }
329                }
330            }
331        }
332
333        // Set timeout if specified
334        if let Some(timeout_secs) = link.request.timeout_secs {
335            request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
336        }
337
338        // Execute pre-request script if configured
339        if let Some(scripting) = &link.request.scripting {
340            if let Some(pre_script) = &scripting.pre_script {
341                let script_context = ScriptContext {
342                    request: Some(link.request.clone()),
343                    response: None,
344                    chain_context: execution_context.templating.chain_context.variables.clone(),
345                    variables: HashMap::new(),
346                    env_vars: std::env::vars().collect(),
347                };
348
349                match self
350                    .script_engine
351                    .execute_script(pre_script, &script_context, scripting.timeout_ms)
352                    .await
353                {
354                    Ok(script_result) => {
355                        // Merge script-modified variables into chain context
356                        for (key, value) in script_result.modified_variables {
357                            execution_context.templating.chain_context.set_variable(key, value);
358                        }
359                    }
360                    Err(e) => {
361                        tracing::warn!(
362                            "Pre-script execution failed for request '{}': {}",
363                            link.request.id,
364                            e
365                        );
366                        // Continue execution even if script fails
367                    }
368                }
369            }
370        }
371
372        // Execute the request
373        let response_result =
374            timeout(Duration::from_secs(self.config.global_timeout_secs), request_builder.send())
375                .await;
376
377        let response = match response_result {
378            Ok(Ok(resp)) => resp,
379            Ok(Err(e)) => {
380                return Err(Error::generic(format!("Request '{}' failed: {}", link.request.id, e)));
381            }
382            Err(_) => {
383                return Err(Error::generic(format!("Request '{}' timed out", link.request.id)));
384            }
385        };
386
387        let status = response.status();
388        let headers: HashMap<String, String> = response
389            .headers()
390            .iter()
391            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
392            .collect();
393
394        let body_text = response.text().await.unwrap_or_default();
395        let body_json: Option<Value> = serde_json::from_str(&body_text).ok();
396
397        let duration_ms = request_start.elapsed().as_millis() as u64;
398        let executed_at = Utc::now().to_rfc3339();
399
400        let chain_response = ChainResponse {
401            status: status.as_u16(),
402            headers,
403            body: body_json,
404            duration_ms,
405            executed_at,
406            error: None,
407        };
408
409        // Validate expected status if specified
410        if let Some(expected) = &link.request.expected_status {
411            if !expected.contains(&status.as_u16()) {
412                let error_msg = format!(
413                    "Request '{}' returned status {} but expected one of {:?}",
414                    link.request.id,
415                    status.as_u16(),
416                    expected
417                );
418                return Err(Error::generic(error_msg));
419            }
420        }
421
422        // Store the response
423        if let Some(store_name) = &link.store_as {
424            execution_context
425                .templating
426                .chain_context
427                .store_response(store_name.clone(), chain_response.clone());
428        }
429
430        // Extract variables from response
431        for (var_name, extraction_path) in &link.extract {
432            if let Some(value) = self.extract_from_response(&chain_response, extraction_path) {
433                execution_context.templating.chain_context.set_variable(var_name.clone(), value);
434            }
435        }
436
437        // Execute post-request script if configured
438        if let Some(scripting) = &link.request.scripting {
439            if let Some(post_script) = &scripting.post_script {
440                let script_context = ScriptContext {
441                    request: Some(link.request.clone()),
442                    response: Some(chain_response.clone()),
443                    chain_context: execution_context.templating.chain_context.variables.clone(),
444                    variables: HashMap::new(),
445                    env_vars: std::env::vars().collect(),
446                };
447
448                match self
449                    .script_engine
450                    .execute_script(post_script, &script_context, scripting.timeout_ms)
451                    .await
452                {
453                    Ok(script_result) => {
454                        // Merge script-modified variables into chain context
455                        for (key, value) in script_result.modified_variables {
456                            execution_context.templating.chain_context.set_variable(key, value);
457                        }
458                    }
459                    Err(e) => {
460                        tracing::warn!(
461                            "Post-script execution failed for request '{}': {}",
462                            link.request.id,
463                            e
464                        );
465                        // Continue execution even if script fails
466                    }
467                }
468            }
469        }
470
471        // Also store by request ID as fallback
472        execution_context
473            .templating
474            .chain_context
475            .store_response(link.request.id.clone(), chain_response);
476
477        Ok(())
478    }
479
480    /// Build dependency graph from chain links
481    fn build_dependency_graph(&self, links: &[ChainLink]) -> HashMap<String, Vec<String>> {
482        let mut graph = HashMap::new();
483
484        for link in links {
485            graph
486                .entry(link.request.id.clone())
487                .or_insert_with(Vec::new)
488                .extend(link.request.depends_on.iter().cloned());
489        }
490
491        graph
492    }
493
494    /// Perform topological sort on dependency graph
495    fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
496        let mut visited = HashSet::new();
497        let mut rec_stack = HashSet::new();
498        let mut result = Vec::new();
499
500        for node in graph.keys() {
501            if !visited.contains(node) {
502                self.topo_sort_util(node, graph, &mut visited, &mut rec_stack, &mut result)?;
503            }
504        }
505
506        result.reverse();
507        Ok(result)
508    }
509
510    /// Utility function for topological sort
511    #[allow(clippy::only_used_in_recursion)]
512    fn topo_sort_util(
513        &self,
514        node: &str,
515        graph: &HashMap<String, Vec<String>>,
516        visited: &mut HashSet<String>,
517        rec_stack: &mut HashSet<String>,
518        result: &mut Vec<String>,
519    ) -> Result<()> {
520        visited.insert(node.to_string());
521        rec_stack.insert(node.to_string());
522
523        if let Some(dependencies) = graph.get(node) {
524            for dep in dependencies {
525                if !visited.contains(dep) {
526                    self.topo_sort_util(dep, graph, visited, rec_stack, result)?;
527                } else if rec_stack.contains(dep) {
528                    return Err(Error::generic(format!(
529                        "Circular dependency detected involving '{}'",
530                        node
531                    )));
532                }
533            }
534        }
535
536        rec_stack.remove(node);
537        result.push(node.to_string());
538        Ok(())
539    }
540
541    /// Collect requests that can be executed in parallel (same dependency level)
542    fn collect_dependency_level(
543        &self,
544        request_id: String,
545        _graph: &HashMap<String, Vec<String>>,
546        level: &mut Vec<String>,
547        processed: &mut HashSet<String>,
548    ) {
549        level.push(request_id.clone());
550        processed.insert(request_id);
551    }
552
553    /// Expand template string with chain context
554    fn expand_template(&self, template: &str, context: &ChainTemplatingContext) -> String {
555        let templating_context = crate::templating::TemplatingContext {
556            chain_context: Some(context.clone()),
557            env_context: None,
558            virtual_clock: None,
559        };
560        crate::templating::expand_str_with_context(template, &templating_context)
561    }
562
563    /// Expand template variables in JSON value
564    fn expand_template_in_json(&self, value: &Value, context: &ChainTemplatingContext) -> Value {
565        match value {
566            Value::String(s) => Value::String(self.expand_template(s, context)),
567            Value::Array(arr) => {
568                Value::Array(arr.iter().map(|v| self.expand_template_in_json(v, context)).collect())
569            }
570            Value::Object(map) => {
571                let mut new_map = serde_json::Map::new();
572                for (k, v) in map {
573                    new_map.insert(
574                        self.expand_template(k, context),
575                        self.expand_template_in_json(v, context),
576                    );
577                }
578                Value::Object(new_map)
579            }
580            _ => value.clone(),
581        }
582    }
583
584    /// Extract value from response using JSON path-like syntax
585    fn extract_from_response(&self, response: &ChainResponse, path: &str) -> Option<Value> {
586        let parts: Vec<&str> = path.split('.').collect();
587
588        if parts.is_empty() || parts[0] != "body" {
589            return None;
590        }
591
592        let mut current = response.body.as_ref()?;
593
594        for part in &parts[1..] {
595            match current {
596                Value::Object(map) => {
597                    current = map.get(*part)?;
598                }
599                Value::Array(arr) => {
600                    if part.starts_with('[') && part.ends_with(']') {
601                        let index_str = &part[1..part.len() - 1];
602                        if let Ok(index) = index_str.parse::<usize>() {
603                            current = arr.get(index)?;
604                        } else {
605                            return None;
606                        }
607                    } else {
608                        return None;
609                    }
610                }
611                _ => return None,
612            }
613        }
614
615        Some(current.clone())
616    }
617}
618
/// Result of executing a request chain
#[derive(Debug, Clone)]
pub struct ChainExecutionResult {
    /// Unique identifier for the executed chain
    pub chain_id: String,
    /// Overall execution status
    pub status: ChainExecutionStatus,
    /// Total duration of chain execution in milliseconds (measured after validation)
    pub total_duration_ms: u64,
    /// Responses collected by the chain context. `ChainExecutionEngine` stores each
    /// request's response under its request ID and, when the link sets `store_as`,
    /// also under that name, so one response may appear under two keys.
    pub request_results: HashMap<String, ChainResponse>,
    /// Error message if execution failed. `ChainExecutionEngine` in this module always
    /// sets `None`: its failures are returned as `Err` instead of a result value.
    pub error_message: Option<String>,
}
633
/// Status of chain execution
///
/// NOTE(review): `ChainExecutionEngine` in this module only ever produces `Successful`;
/// the other variants are presumably set by other callers. Verify before relying on them.
#[derive(Debug, Clone, PartialEq)]
pub enum ChainExecutionStatus {
    /// All requests in the chain succeeded
    Successful,
    /// Some requests succeeded but others failed
    PartialSuccess,
    /// Chain execution failed
    Failed,
}
644
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Engine over an empty registry, both configured with `ChainConfig::default()`.
    fn engine_with_defaults() -> ChainExecutionEngine {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        ChainExecutionEngine::new(registry, ChainConfig::default())
    }

    /// Build a `request -> depends_on` graph from string-slice pairs.
    fn graph_of(edges: Vec<(&str, Vec<&str>)>) -> HashMap<String, Vec<String>> {
        edges
            .into_iter()
            .map(|(node, deps)| (node.to_string(), deps.into_iter().map(String::from).collect()))
            .collect()
    }

    #[tokio::test]
    async fn test_engine_creation() {
        // Construction alone must succeed without panicking.
        let _engine = engine_with_defaults();
    }

    #[tokio::test]
    async fn test_topological_sort() {
        let engine = engine_with_defaults();
        let graph = graph_of(vec![
            ("A", vec![]),
            ("B", vec!["A"]),
            ("C", vec!["A"]),
            ("D", vec!["B", "C"]),
        ]);

        let order = engine.topological_sort(&graph).unwrap();

        // The sort places every node ahead of the nodes it depends on.
        let index_of = |id: &str| order.iter().position(|x| x == id).unwrap();
        let (d, b, c, a) = (index_of("D"), index_of("B"), index_of("C"), index_of("A"));

        assert!(d < b, "D should come before B");
        assert!(d < c, "D should come before C");
        assert!(b < a, "B should come before A");
        assert!(c < a, "C should come before A");
        assert_eq!(order.len(), 4, "Should have all 4 nodes");
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let engine = engine_with_defaults();
        // A and B depend on each other.
        let graph = graph_of(vec![("A", vec!["B"]), ("B", vec!["A"])]);

        assert!(engine.topological_sort(&graph).is_err());
    }
}
699}