mockforge_core/
chain_execution.rs

1//! Chain execution engine for request chaining
2//!
3//! This module provides the execution engine that manages chain execution with
4//! dependency resolution, parallel execution when possible, and proper error handling.
5
6use crate::request_chaining::{
7    ChainConfig, ChainDefinition, ChainExecutionContext, ChainLink, ChainResponse,
8    ChainTemplatingContext, RequestChainRegistry,
9};
10use crate::request_scripting::{ScriptContext, ScriptEngine};
11use crate::templating::{expand_str_with_context, TemplatingContext};
12use crate::{Error, Result};
13use chrono::Utc;
14use futures::future::join_all;
15use reqwest::{
16    header::{HeaderMap, HeaderName, HeaderValue},
17    Client, Method,
18};
19use serde_json::Value;
20use std::collections::{HashMap, HashSet};
21use std::str::FromStr;
22use std::sync::Arc;
23use tokio::sync::Mutex;
24use tokio::time::{timeout, Duration};
25
26/// Record of a chain execution with timestamp
27#[derive(Debug, Clone)]
28pub struct ExecutionRecord {
29    /// ISO 8601 timestamp when the chain was executed
30    pub executed_at: String,
31    /// Result of the chain execution
32    pub result: ChainExecutionResult,
33}
34
35/// Engine for executing request chains
36#[derive(Debug)]
37pub struct ChainExecutionEngine {
38    /// HTTP client for making requests
39    http_client: Client,
40    /// Chain registry
41    registry: Arc<RequestChainRegistry>,
42    /// Global configuration
43    config: ChainConfig,
44    /// Execution history storage (chain_id -> Vec<ExecutionRecord>)
45    execution_history: Arc<Mutex<HashMap<String, Vec<ExecutionRecord>>>>,
46    /// JavaScript scripting engine for pre/post request scripts
47    script_engine: ScriptEngine,
48}
49
50impl ChainExecutionEngine {
51    /// Create a new chain execution engine
52    ///
53    /// # Panics
54    ///
55    /// This method will panic if the HTTP client cannot be created, which typically
56    /// indicates a system configuration issue. For better error handling, use `try_new()`.
57    pub fn new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Self {
58        Self::try_new(registry, config)
59            .unwrap_or_else(|e| {
60                panic!(
61                    "Failed to create HTTP client for chain execution engine: {}. \
62                    This typically indicates a system configuration issue (e.g., invalid timeout value).",
63                    e
64                )
65            })
66    }
67
68    /// Try to create a new chain execution engine
69    ///
70    /// Returns an error if the HTTP client cannot be created.
71    pub fn try_new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Result<Self> {
72        let http_client = Client::builder()
73            .timeout(Duration::from_secs(config.global_timeout_secs))
74            .build()
75            .map_err(|e| {
76                Error::generic(format!(
77                    "Failed to create HTTP client: {}. \
78                Check that the timeout value ({}) is valid.",
79                    e, config.global_timeout_secs
80                ))
81            })?;
82
83        Ok(Self {
84            http_client,
85            registry,
86            config,
87            execution_history: Arc::new(Mutex::new(HashMap::new())),
88            script_engine: ScriptEngine::new(),
89        })
90    }
91
92    /// Execute a chain by ID
93    pub async fn execute_chain(
94        &self,
95        chain_id: &str,
96        variables: Option<serde_json::Value>,
97    ) -> Result<ChainExecutionResult> {
98        let chain = self
99            .registry
100            .get_chain(chain_id)
101            .await
102            .ok_or_else(|| Error::generic(format!("Chain '{}' not found", chain_id)))?;
103
104        let result = self.execute_chain_definition(&chain, variables).await?;
105
106        // Store execution in history
107        let record = ExecutionRecord {
108            executed_at: Utc::now().to_rfc3339(),
109            result: result.clone(),
110        };
111
112        let mut history = self.execution_history.lock().await;
113        history.entry(chain_id.to_string()).or_insert_with(Vec::new).push(record);
114
115        Ok(result)
116    }
117
118    /// Get execution history for a chain
119    pub async fn get_chain_history(&self, chain_id: &str) -> Vec<ExecutionRecord> {
120        let history = self.execution_history.lock().await;
121        history.get(chain_id).cloned().unwrap_or_default()
122    }
123
124    /// Execute a chain definition
125    pub async fn execute_chain_definition(
126        &self,
127        chain_def: &ChainDefinition,
128        variables: Option<serde_json::Value>,
129    ) -> Result<ChainExecutionResult> {
130        // First validate the chain
131        self.registry.validate_chain(chain_def).await?;
132
133        let start_time = std::time::Instant::now();
134        let mut execution_context = ChainExecutionContext::new(chain_def.clone());
135
136        // Initialize context with chain variables
137        for (key, value) in &chain_def.variables {
138            execution_context
139                .templating
140                .chain_context
141                .set_variable(key.clone(), value.clone());
142        }
143
144        // Merge custom variables from request
145        if let Some(serde_json::Value::Object(map)) = variables {
146            for (key, value) in map {
147                execution_context.templating.chain_context.set_variable(key, value);
148            }
149        }
150
151        if self.config.enable_parallel_execution {
152            self.execute_with_parallelism(&mut execution_context).await
153        } else {
154            self.execute_sequential(&mut execution_context).await
155        }
156        .map(|_| ChainExecutionResult {
157            chain_id: chain_def.id.clone(),
158            status: ChainExecutionStatus::Successful,
159            total_duration_ms: start_time.elapsed().as_millis() as u64,
160            request_results: execution_context.templating.chain_context.responses.clone(),
161            error_message: None,
162        })
163    }
164
165    /// Execute chain using topological sorting for parallelism
166    async fn execute_with_parallelism(
167        &self,
168        execution_context: &mut ChainExecutionContext,
169    ) -> Result<()> {
170        let dep_graph = self.build_dependency_graph(&execution_context.definition.links);
171        let topo_order = self.topological_sort(&dep_graph)?;
172
173        // Group requests by dependency level
174        let mut level_groups = vec![];
175        let mut processed = HashSet::new();
176
177        for request_id in topo_order {
178            if !processed.contains(&request_id) {
179                let mut level = vec![];
180                self.collect_dependency_level(request_id, &dep_graph, &mut level, &mut processed);
181                level_groups.push(level);
182            }
183        }
184
185        // Execute levels in parallel
186        for level in level_groups {
187            if level.len() == 1 {
188                // Single request, execute directly
189                let request_id = &level[0];
190                let link = execution_context
191                    .definition
192                    .links
193                    .iter()
194                    .find(|l| l.request.id == *request_id)
195                    .unwrap();
196
197                let link_clone = link.clone();
198                self.execute_request(&link_clone, execution_context).await?;
199            } else {
200                // Execute level in parallel
201                let tasks = level
202                    .into_iter()
203                    .map(|request_id| {
204                        let link = execution_context
205                            .definition
206                            .links
207                            .iter()
208                            .find(|l| l.request.id == request_id)
209                            .unwrap()
210                            .clone();
211                        // Create a new context for parallel execution
212                        let parallel_context = ChainExecutionContext {
213                            definition: execution_context.definition.clone(),
214                            templating: execution_context.templating.clone(),
215                            start_time: std::time::Instant::now(),
216                            config: execution_context.config.clone(),
217                        };
218
219                        let context = Arc::new(Mutex::new(parallel_context));
220                        let engine =
221                            ChainExecutionEngine::new(self.registry.clone(), self.config.clone());
222
223                        tokio::spawn(async move {
224                            let mut ctx = context.lock().await;
225                            engine.execute_request(&link, &mut ctx).await
226                        })
227                    })
228                    .collect::<Vec<_>>();
229
230                let results = join_all(tasks).await;
231                for result in results {
232                    result
233                        .map_err(|e| Error::generic(format!("Task join error: {}", e)))?
234                        .map_err(|e| Error::generic(format!("Request execution error: {}", e)))?;
235                }
236            }
237        }
238
239        Ok(())
240    }
241
242    /// Execute requests sequentially
243    async fn execute_sequential(
244        &self,
245        execution_context: &mut ChainExecutionContext,
246    ) -> Result<()> {
247        let links = execution_context.definition.links.clone();
248        for link in &links {
249            self.execute_request(link, execution_context).await?;
250        }
251        Ok(())
252    }
253
254    /// Execute a single request in the chain
255    async fn execute_request(
256        &self,
257        link: &ChainLink,
258        execution_context: &mut ChainExecutionContext,
259    ) -> Result<()> {
260        let request_start = std::time::Instant::now();
261
262        // Prepare the request with templating
263        execution_context.templating.set_current_request(link.request.clone());
264
265        let method = Method::from_bytes(link.request.method.as_bytes()).map_err(|e| {
266            Error::generic(format!("Invalid HTTP method '{}': {}", link.request.method, e))
267        })?;
268
269        let url = self.expand_template(&link.request.url, &execution_context.templating);
270
271        // Prepare headers
272        let mut headers = HeaderMap::new();
273        for (key, value) in &link.request.headers {
274            let expanded_value = self.expand_template(value, &execution_context.templating);
275            let header_name = HeaderName::from_str(key)
276                .map_err(|e| Error::generic(format!("Invalid header name '{}': {}", key, e)))?;
277            let header_value = HeaderValue::from_str(&expanded_value).map_err(|e| {
278                Error::generic(format!("Invalid header value for '{}': {}", key, e))
279            })?;
280            headers.insert(header_name, header_value);
281        }
282
283        // Prepare request builder
284        let mut request_builder = self.http_client.request(method, &url).headers(headers.clone());
285
286        // Add body if present
287        if let Some(body) = &link.request.body {
288            match body {
289                crate::request_chaining::RequestBody::Json(json_value) => {
290                    let expanded_body =
291                        self.expand_template_in_json(json_value, &execution_context.templating);
292                    request_builder = request_builder.json(&expanded_body);
293                }
294                crate::request_chaining::RequestBody::BinaryFile { path, content_type } => {
295                    // Create templating context for path expansion
296                    let templating_context =
297                        TemplatingContext::with_chain(execution_context.templating.clone());
298
299                    // Expand templates in the file path
300                    let expanded_path = expand_str_with_context(path, &templating_context);
301
302                    // Create a new body with expanded path
303                    let binary_body = crate::request_chaining::RequestBody::binary_file(
304                        expanded_path,
305                        content_type.clone(),
306                    );
307
308                    // Read the binary file
309                    match binary_body.to_bytes().await {
310                        Ok(file_bytes) => {
311                            request_builder = request_builder.body(file_bytes);
312
313                            // Set content type if specified
314                            if let Some(ct) = content_type {
315                                let mut headers = headers.clone();
316                                headers.insert(
317                                    "content-type",
318                                    ct.parse().unwrap_or_else(|_| {
319                                        reqwest::header::HeaderValue::from_static(
320                                            "application/octet-stream",
321                                        )
322                                    }),
323                                );
324                                request_builder = request_builder.headers(headers);
325                            }
326                        }
327                        Err(e) => {
328                            return Err(e);
329                        }
330                    }
331                }
332            }
333        }
334
335        // Set timeout if specified
336        if let Some(timeout_secs) = link.request.timeout_secs {
337            request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
338        }
339
340        // Execute pre-request script if configured
341        if let Some(scripting) = &link.request.scripting {
342            if let Some(pre_script) = &scripting.pre_script {
343                let script_context = ScriptContext {
344                    request: Some(link.request.clone()),
345                    response: None,
346                    chain_context: execution_context.templating.chain_context.variables.clone(),
347                    variables: HashMap::new(),
348                    env_vars: std::env::vars().collect(),
349                };
350
351                match self
352                    .script_engine
353                    .execute_script(pre_script, &script_context, scripting.timeout_ms)
354                    .await
355                {
356                    Ok(script_result) => {
357                        // Merge script-modified variables into chain context
358                        for (key, value) in script_result.modified_variables {
359                            execution_context.templating.chain_context.set_variable(key, value);
360                        }
361                    }
362                    Err(e) => {
363                        tracing::warn!(
364                            "Pre-script execution failed for request '{}': {}",
365                            link.request.id,
366                            e
367                        );
368                        // Continue execution even if script fails
369                    }
370                }
371            }
372        }
373
374        // Execute the request
375        let response_result =
376            timeout(Duration::from_secs(self.config.global_timeout_secs), request_builder.send())
377                .await;
378
379        let response = match response_result {
380            Ok(Ok(resp)) => resp,
381            Ok(Err(e)) => {
382                return Err(Error::generic(format!("Request '{}' failed: {}", link.request.id, e)));
383            }
384            Err(_) => {
385                return Err(Error::generic(format!("Request '{}' timed out", link.request.id)));
386            }
387        };
388
389        let status = response.status();
390        let headers: HashMap<String, String> = response
391            .headers()
392            .iter()
393            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
394            .collect();
395
396        let body_text = response.text().await.unwrap_or_default();
397        let body_json: Option<Value> = serde_json::from_str(&body_text).ok();
398
399        let duration_ms = request_start.elapsed().as_millis() as u64;
400        let executed_at = Utc::now().to_rfc3339();
401
402        let chain_response = ChainResponse {
403            status: status.as_u16(),
404            headers,
405            body: body_json,
406            duration_ms,
407            executed_at,
408            error: None,
409        };
410
411        // Validate expected status if specified
412        if let Some(expected) = &link.request.expected_status {
413            if !expected.contains(&status.as_u16()) {
414                let error_msg = format!(
415                    "Request '{}' returned status {} but expected one of {:?}",
416                    link.request.id,
417                    status.as_u16(),
418                    expected
419                );
420                return Err(Error::generic(error_msg));
421            }
422        }
423
424        // Store the response
425        if let Some(store_name) = &link.store_as {
426            execution_context
427                .templating
428                .chain_context
429                .store_response(store_name.clone(), chain_response.clone());
430        }
431
432        // Extract variables from response
433        for (var_name, extraction_path) in &link.extract {
434            if let Some(value) = self.extract_from_response(&chain_response, extraction_path) {
435                execution_context.templating.chain_context.set_variable(var_name.clone(), value);
436            }
437        }
438
439        // Execute post-request script if configured
440        if let Some(scripting) = &link.request.scripting {
441            if let Some(post_script) = &scripting.post_script {
442                let script_context = ScriptContext {
443                    request: Some(link.request.clone()),
444                    response: Some(chain_response.clone()),
445                    chain_context: execution_context.templating.chain_context.variables.clone(),
446                    variables: HashMap::new(),
447                    env_vars: std::env::vars().collect(),
448                };
449
450                match self
451                    .script_engine
452                    .execute_script(post_script, &script_context, scripting.timeout_ms)
453                    .await
454                {
455                    Ok(script_result) => {
456                        // Merge script-modified variables into chain context
457                        for (key, value) in script_result.modified_variables {
458                            execution_context.templating.chain_context.set_variable(key, value);
459                        }
460                    }
461                    Err(e) => {
462                        tracing::warn!(
463                            "Post-script execution failed for request '{}': {}",
464                            link.request.id,
465                            e
466                        );
467                        // Continue execution even if script fails
468                    }
469                }
470            }
471        }
472
473        // Also store by request ID as fallback
474        execution_context
475            .templating
476            .chain_context
477            .store_response(link.request.id.clone(), chain_response);
478
479        Ok(())
480    }
481
482    /// Build dependency graph from chain links
483    fn build_dependency_graph(&self, links: &[ChainLink]) -> HashMap<String, Vec<String>> {
484        let mut graph = HashMap::new();
485
486        for link in links {
487            graph
488                .entry(link.request.id.clone())
489                .or_insert_with(Vec::new)
490                .extend(link.request.depends_on.iter().cloned());
491        }
492
493        graph
494    }
495
496    /// Perform topological sort on dependency graph
497    fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
498        let mut visited = HashSet::new();
499        let mut rec_stack = HashSet::new();
500        let mut result = Vec::new();
501
502        for node in graph.keys() {
503            if !visited.contains(node) {
504                self.topo_sort_util(node, graph, &mut visited, &mut rec_stack, &mut result)?;
505            }
506        }
507
508        result.reverse();
509        Ok(result)
510    }
511
512    /// Utility function for topological sort
513    #[allow(clippy::only_used_in_recursion)]
514    fn topo_sort_util(
515        &self,
516        node: &str,
517        graph: &HashMap<String, Vec<String>>,
518        visited: &mut HashSet<String>,
519        rec_stack: &mut HashSet<String>,
520        result: &mut Vec<String>,
521    ) -> Result<()> {
522        visited.insert(node.to_string());
523        rec_stack.insert(node.to_string());
524
525        if let Some(dependencies) = graph.get(node) {
526            for dep in dependencies {
527                if !visited.contains(dep) {
528                    self.topo_sort_util(dep, graph, visited, rec_stack, result)?;
529                } else if rec_stack.contains(dep) {
530                    return Err(Error::generic(format!(
531                        "Circular dependency detected involving '{}'",
532                        node
533                    )));
534                }
535            }
536        }
537
538        rec_stack.remove(node);
539        result.push(node.to_string());
540        Ok(())
541    }
542
543    /// Collect requests that can be executed in parallel (same dependency level)
544    fn collect_dependency_level(
545        &self,
546        request_id: String,
547        _graph: &HashMap<String, Vec<String>>,
548        level: &mut Vec<String>,
549        processed: &mut HashSet<String>,
550    ) {
551        level.push(request_id.clone());
552        processed.insert(request_id);
553    }
554
555    /// Expand template string with chain context
556    fn expand_template(&self, template: &str, context: &ChainTemplatingContext) -> String {
557        let templating_context = crate::templating::TemplatingContext {
558            chain_context: Some(context.clone()),
559            env_context: None,
560            virtual_clock: None,
561        };
562        crate::templating::expand_str_with_context(template, &templating_context)
563    }
564
565    /// Expand template variables in JSON value
566    fn expand_template_in_json(&self, value: &Value, context: &ChainTemplatingContext) -> Value {
567        match value {
568            Value::String(s) => Value::String(self.expand_template(s, context)),
569            Value::Array(arr) => {
570                Value::Array(arr.iter().map(|v| self.expand_template_in_json(v, context)).collect())
571            }
572            Value::Object(map) => {
573                let mut new_map = serde_json::Map::new();
574                for (k, v) in map {
575                    new_map.insert(
576                        self.expand_template(k, context),
577                        self.expand_template_in_json(v, context),
578                    );
579                }
580                Value::Object(new_map)
581            }
582            _ => value.clone(),
583        }
584    }
585
586    /// Extract value from response using JSON path-like syntax
587    fn extract_from_response(&self, response: &ChainResponse, path: &str) -> Option<Value> {
588        let parts: Vec<&str> = path.split('.').collect();
589
590        if parts.is_empty() || parts[0] != "body" {
591            return None;
592        }
593
594        let mut current = response.body.as_ref()?;
595
596        for part in &parts[1..] {
597            match current {
598                Value::Object(map) => {
599                    current = map.get(*part)?;
600                }
601                Value::Array(arr) => {
602                    if part.starts_with('[') && part.ends_with(']') {
603                        let index_str = &part[1..part.len() - 1];
604                        if let Ok(index) = index_str.parse::<usize>() {
605                            current = arr.get(index)?;
606                        } else {
607                            return None;
608                        }
609                    } else {
610                        return None;
611                    }
612                }
613                _ => return None,
614            }
615        }
616
617        Some(current.clone())
618    }
619}
620
621/// Result of executing a request chain
622#[derive(Debug, Clone)]
623pub struct ChainExecutionResult {
624    /// Unique identifier for the executed chain
625    pub chain_id: String,
626    /// Overall execution status
627    pub status: ChainExecutionStatus,
628    /// Total duration of chain execution in milliseconds
629    pub total_duration_ms: u64,
630    /// Results of individual requests in the chain, keyed by request ID
631    pub request_results: HashMap<String, ChainResponse>,
632    /// Error message if execution failed
633    pub error_message: Option<String>,
634}
635
/// Overall outcome of a chain execution.
///
/// A fieldless enum, so it is `Copy` and can be compared, hashed, and used as a
/// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChainExecutionStatus {
    /// All requests in the chain succeeded
    Successful,
    /// Some requests succeeded but others failed
    PartialSuccess,
    /// Chain execution failed
    Failed,
}
646
#[cfg(test)]
mod tests {
    use super::*;

    fn test_engine() -> ChainExecutionEngine {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        ChainExecutionEngine::new(registry, ChainConfig::default())
    }

    fn response_with_body(body: Option<Value>) -> ChainResponse {
        ChainResponse {
            status: 200,
            headers: HashMap::new(),
            body,
            duration_ms: 0,
            executed_at: String::new(),
            error: None,
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        // Engine should be created successfully with default configuration.
        let _engine = test_engine();
    }

    #[tokio::test]
    async fn test_topological_sort() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec![]);
        graph.insert("B".to_string(), vec!["A".to_string()]);
        graph.insert("C".to_string(), vec!["A".to_string()]);
        graph.insert("D".to_string(), vec!["B".to_string(), "C".to_string()]);

        let topo_order = engine.topological_sort(&graph).unwrap();

        // `topological_sort` lists dependents before their dependencies
        // (the reverse of execution order).
        let d_pos = topo_order.iter().position(|x| x == "D").unwrap();
        let b_pos = topo_order.iter().position(|x| x == "B").unwrap();
        let c_pos = topo_order.iter().position(|x| x == "C").unwrap();
        let a_pos = topo_order.iter().position(|x| x == "A").unwrap();

        assert!(d_pos < b_pos, "D should come before B");
        assert!(d_pos < c_pos, "D should come before C");
        assert!(b_pos < a_pos, "B should come before A");
        assert!(c_pos < a_pos, "C should come before A");
        assert_eq!(topo_order.len(), 4, "Should have all 4 nodes");
    }

    #[tokio::test]
    async fn test_topological_sort_includes_dependencies_missing_from_graph() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("B".to_string(), vec!["A".to_string()]);

        let topo_order = engine.topological_sort(&graph).unwrap();
        assert_eq!(topo_order, vec!["B".to_string(), "A".to_string()]);
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["B".to_string()]);
        graph.insert("B".to_string(), vec!["A".to_string()]); // Circular dependency

        assert!(engine.topological_sort(&graph).is_err());
    }

    #[tokio::test]
    async fn test_self_dependency_is_circular() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["A".to_string()]);

        assert!(engine.topological_sort(&graph).is_err());
    }

    #[tokio::test]
    async fn test_extract_from_response_paths() {
        let engine = test_engine();
        let body = serde_json::json!({
            "user": { "id": 7 },
            "items": [ { "name": "a" }, { "name": "b" } ]
        });
        let response = response_with_body(Some(body.clone()));

        assert_eq!(engine.extract_from_response(&response, "body"), Some(body));
        assert_eq!(
            engine.extract_from_response(&response, "body.user.id"),
            Some(serde_json::json!(7))
        );
        assert_eq!(
            engine.extract_from_response(&response, "body.items.[1].name"),
            Some(serde_json::json!("b"))
        );
        assert_eq!(engine.extract_from_response(&response, "body.missing"), None);
        assert_eq!(engine.extract_from_response(&response, "body.items.[5]"), None);
        assert_eq!(engine.extract_from_response(&response, "headers.x"), None);
        assert_eq!(engine.extract_from_response(&response_with_body(None), "body"), None);
    }
}