mockforge_core/
chain_execution.rs

1//! Chain execution engine for request chaining
2//!
3//! This module provides the execution engine that manages chain execution with
4//! dependency resolution, parallel execution when possible, and proper error handling.
5
6use crate::request_chaining::{
7    ChainConfig, ChainDefinition, ChainExecutionContext, ChainLink, ChainResponse,
8    ChainTemplatingContext, RequestChainRegistry,
9};
10use crate::templating::{expand_str_with_context, TemplatingContext};
11use crate::{Error, Result};
12use chrono::Utc;
13use futures::future::join_all;
14use reqwest::{
15    header::{HeaderMap, HeaderName, HeaderValue},
16    Client, Method,
17};
18use serde_json::Value;
19use std::collections::{HashMap, HashSet};
20use std::str::FromStr;
21use std::sync::Arc;
22use tokio::sync::Mutex;
23use tokio::time::{timeout, Duration};
24
25/// Record of a chain execution with timestamp
26#[derive(Debug, Clone)]
27pub struct ExecutionRecord {
28    pub executed_at: String,
29    pub result: ChainExecutionResult,
30}
31
32/// Engine for executing request chains
33#[derive(Debug)]
34pub struct ChainExecutionEngine {
35    /// HTTP client for making requests
36    http_client: Client,
37    /// Chain registry
38    registry: Arc<RequestChainRegistry>,
39    /// Global configuration
40    config: ChainConfig,
41    /// Execution history storage (chain_id -> Vec<ExecutionRecord>)
42    execution_history: Arc<Mutex<HashMap<String, Vec<ExecutionRecord>>>>,
43}
44
45impl ChainExecutionEngine {
46    /// Create a new chain execution engine
47    pub fn new(registry: Arc<RequestChainRegistry>, config: ChainConfig) -> Self {
48        let http_client = Client::builder()
49            .timeout(Duration::from_secs(config.global_timeout_secs))
50            .build()
51            .expect("Failed to create HTTP client");
52
53        Self {
54            http_client,
55            registry,
56            config,
57            execution_history: Arc::new(Mutex::new(HashMap::new())),
58        }
59    }
60
61    /// Execute a chain by ID
62    pub async fn execute_chain(
63        &self,
64        chain_id: &str,
65        variables: Option<serde_json::Value>,
66    ) -> Result<ChainExecutionResult> {
67        let chain = self
68            .registry
69            .get_chain(chain_id)
70            .await
71            .ok_or_else(|| Error::generic(format!("Chain '{}' not found", chain_id)))?;
72
73        let result = self.execute_chain_definition(&chain, variables).await?;
74
75        // Store execution in history
76        let record = ExecutionRecord {
77            executed_at: Utc::now().to_rfc3339(),
78            result: result.clone(),
79        };
80
81        let mut history = self.execution_history.lock().await;
82        history.entry(chain_id.to_string()).or_insert_with(Vec::new).push(record);
83
84        Ok(result)
85    }
86
87    /// Get execution history for a chain
88    pub async fn get_chain_history(&self, chain_id: &str) -> Vec<ExecutionRecord> {
89        let history = self.execution_history.lock().await;
90        history.get(chain_id).cloned().unwrap_or_default()
91    }
92
93    /// Execute a chain definition
94    pub async fn execute_chain_definition(
95        &self,
96        chain_def: &ChainDefinition,
97        variables: Option<serde_json::Value>,
98    ) -> Result<ChainExecutionResult> {
99        // First validate the chain
100        self.registry.validate_chain(chain_def).await?;
101
102        let start_time = std::time::Instant::now();
103        let mut execution_context = ChainExecutionContext::new(chain_def.clone());
104
105        // Initialize context with chain variables
106        for (key, value) in &chain_def.variables {
107            execution_context
108                .templating
109                .chain_context
110                .set_variable(key.clone(), value.clone());
111        }
112
113        // Merge custom variables from request
114        if let Some(serde_json::Value::Object(map)) = variables {
115            for (key, value) in map {
116                execution_context.templating.chain_context.set_variable(key, value);
117            }
118        }
119
120        if self.config.enable_parallel_execution {
121            self.execute_with_parallelism(&mut execution_context).await
122        } else {
123            self.execute_sequential(&mut execution_context).await
124        }
125        .map(|_| ChainExecutionResult {
126            chain_id: chain_def.id.clone(),
127            status: ChainExecutionStatus::Successful,
128            total_duration_ms: start_time.elapsed().as_millis() as u64,
129            request_results: execution_context.templating.chain_context.responses.clone(),
130            error_message: None,
131        })
132    }
133
134    /// Execute chain using topological sorting for parallelism
135    async fn execute_with_parallelism(
136        &self,
137        execution_context: &mut ChainExecutionContext,
138    ) -> Result<()> {
139        let dep_graph = self.build_dependency_graph(&execution_context.definition.links);
140        let topo_order = self.topological_sort(&dep_graph)?;
141
142        // Group requests by dependency level
143        let mut level_groups = vec![];
144        let mut processed = HashSet::new();
145
146        for request_id in topo_order {
147            if !processed.contains(&request_id) {
148                let mut level = vec![];
149                self.collect_dependency_level(request_id, &dep_graph, &mut level, &mut processed);
150                level_groups.push(level);
151            }
152        }
153
154        // Execute levels in parallel
155        for level in level_groups {
156            if level.len() == 1 {
157                // Single request, execute directly
158                let request_id = &level[0];
159                let link = execution_context
160                    .definition
161                    .links
162                    .iter()
163                    .find(|l| l.request.id == *request_id)
164                    .unwrap();
165
166                let link_clone = link.clone();
167                self.execute_request(&link_clone, execution_context).await?;
168            } else {
169                // Execute level in parallel
170                let tasks = level
171                    .into_iter()
172                    .map(|request_id| {
173                        let link = execution_context
174                            .definition
175                            .links
176                            .iter()
177                            .find(|l| l.request.id == request_id)
178                            .unwrap()
179                            .clone();
180                        // Create a new context for parallel execution
181                        let parallel_context = ChainExecutionContext {
182                            definition: execution_context.definition.clone(),
183                            templating: execution_context.templating.clone(),
184                            start_time: std::time::Instant::now(),
185                            config: execution_context.config.clone(),
186                        };
187
188                        let context = Arc::new(Mutex::new(parallel_context));
189                        let engine =
190                            ChainExecutionEngine::new(self.registry.clone(), self.config.clone());
191
192                        tokio::spawn(async move {
193                            let mut ctx = context.lock().await;
194                            engine.execute_request(&link, &mut ctx).await
195                        })
196                    })
197                    .collect::<Vec<_>>();
198
199                let results = join_all(tasks).await;
200                for result in results {
201                    result
202                        .map_err(|e| Error::generic(format!("Task join error: {}", e)))?
203                        .map_err(|e| Error::generic(format!("Request execution error: {}", e)))?;
204                }
205            }
206        }
207
208        Ok(())
209    }
210
211    /// Execute requests sequentially
212    async fn execute_sequential(
213        &self,
214        execution_context: &mut ChainExecutionContext,
215    ) -> Result<()> {
216        let links = execution_context.definition.links.clone();
217        for link in &links {
218            self.execute_request(link, execution_context).await?;
219        }
220        Ok(())
221    }
222
223    /// Execute a single request in the chain
224    async fn execute_request(
225        &self,
226        link: &ChainLink,
227        execution_context: &mut ChainExecutionContext,
228    ) -> Result<()> {
229        let request_start = std::time::Instant::now();
230
231        // Prepare the request with templating
232        execution_context.templating.set_current_request(link.request.clone());
233
234        let method = Method::from_bytes(link.request.method.as_bytes()).map_err(|e| {
235            Error::generic(format!("Invalid HTTP method '{}': {}", link.request.method, e))
236        })?;
237
238        let url = self.expand_template(&link.request.url, &execution_context.templating);
239
240        // Prepare headers
241        let mut headers = HeaderMap::new();
242        for (key, value) in &link.request.headers {
243            let expanded_value = self.expand_template(value, &execution_context.templating);
244            let header_name = HeaderName::from_str(key)
245                .map_err(|e| Error::generic(format!("Invalid header name '{}': {}", key, e)))?;
246            let header_value = HeaderValue::from_str(&expanded_value).map_err(|e| {
247                Error::generic(format!("Invalid header value for '{}': {}", key, e))
248            })?;
249            headers.insert(header_name, header_value);
250        }
251
252        // Prepare request builder
253        let mut request_builder = self.http_client.request(method, &url).headers(headers.clone());
254
255        // Add body if present
256        if let Some(body) = &link.request.body {
257            match body {
258                crate::request_chaining::RequestBody::Json(json_value) => {
259                    let expanded_body =
260                        self.expand_template_in_json(json_value, &execution_context.templating);
261                    request_builder = request_builder.json(&expanded_body);
262                }
263                crate::request_chaining::RequestBody::BinaryFile { path, content_type } => {
264                    // Create templating context for path expansion
265                    let templating_context =
266                        TemplatingContext::with_chain(execution_context.templating.clone());
267
268                    // Expand templates in the file path
269                    let expanded_path = expand_str_with_context(path, &templating_context);
270
271                    // Create a new body with expanded path
272                    let binary_body = crate::request_chaining::RequestBody::binary_file(
273                        expanded_path,
274                        content_type.clone(),
275                    );
276
277                    // Read the binary file
278                    match binary_body.to_bytes().await {
279                        Ok(file_bytes) => {
280                            request_builder = request_builder.body(file_bytes);
281
282                            // Set content type if specified
283                            if let Some(ct) = content_type {
284                                let mut headers = headers.clone();
285                                headers.insert(
286                                    "content-type",
287                                    ct.parse().unwrap_or_else(|_| {
288                                        reqwest::header::HeaderValue::from_static(
289                                            "application/octet-stream",
290                                        )
291                                    }),
292                                );
293                                request_builder = request_builder.headers(headers);
294                            }
295                        }
296                        Err(e) => {
297                            return Err(e);
298                        }
299                    }
300                }
301            }
302        }
303
304        // Set timeout if specified
305        if let Some(timeout_secs) = link.request.timeout_secs {
306            request_builder = request_builder.timeout(Duration::from_secs(timeout_secs));
307        }
308
309        // Execute the request
310        let response_result =
311            timeout(Duration::from_secs(self.config.global_timeout_secs), request_builder.send())
312                .await;
313
314        let response = match response_result {
315            Ok(Ok(resp)) => resp,
316            Ok(Err(e)) => {
317                return Err(Error::generic(format!("Request '{}' failed: {}", link.request.id, e)));
318            }
319            Err(_) => {
320                return Err(Error::generic(format!("Request '{}' timed out", link.request.id)));
321            }
322        };
323
324        let status = response.status();
325        let headers: HashMap<String, String> = response
326            .headers()
327            .iter()
328            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
329            .collect();
330
331        let body_text = response.text().await.unwrap_or_default();
332        let body_json: Option<Value> = serde_json::from_str(&body_text).ok();
333
334        let duration_ms = request_start.elapsed().as_millis() as u64;
335        let executed_at = Utc::now().to_rfc3339();
336
337        let chain_response = ChainResponse {
338            status: status.as_u16(),
339            headers,
340            body: body_json,
341            duration_ms,
342            executed_at,
343            error: None,
344        };
345
346        // Validate expected status if specified
347        if let Some(expected) = &link.request.expected_status {
348            if !expected.contains(&status.as_u16()) {
349                let error_msg = format!(
350                    "Request '{}' returned status {} but expected one of {:?}",
351                    link.request.id,
352                    status.as_u16(),
353                    expected
354                );
355                return Err(Error::generic(error_msg));
356            }
357        }
358
359        // Store the response
360        if let Some(store_name) = &link.store_as {
361            execution_context
362                .templating
363                .chain_context
364                .store_response(store_name.clone(), chain_response.clone());
365        }
366
367        // Extract variables from response
368        for (var_name, extraction_path) in &link.extract {
369            if let Some(value) = self.extract_from_response(&chain_response, extraction_path) {
370                execution_context.templating.chain_context.set_variable(var_name.clone(), value);
371            }
372        }
373
374        // Also store by request ID as fallback
375        execution_context
376            .templating
377            .chain_context
378            .store_response(link.request.id.clone(), chain_response);
379
380        Ok(())
381    }
382
383    /// Build dependency graph from chain links
384    fn build_dependency_graph(&self, links: &[ChainLink]) -> HashMap<String, Vec<String>> {
385        let mut graph = HashMap::new();
386
387        for link in links {
388            graph
389                .entry(link.request.id.clone())
390                .or_insert_with(Vec::new)
391                .extend(link.request.depends_on.iter().cloned());
392        }
393
394        graph
395    }
396
397    /// Perform topological sort on dependency graph
398    fn topological_sort(&self, graph: &HashMap<String, Vec<String>>) -> Result<Vec<String>> {
399        let mut visited = HashSet::new();
400        let mut rec_stack = HashSet::new();
401        let mut result = Vec::new();
402
403        for node in graph.keys() {
404            if !visited.contains(node) {
405                self.topo_sort_util(node, graph, &mut visited, &mut rec_stack, &mut result)?;
406            }
407        }
408
409        result.reverse();
410        Ok(result)
411    }
412
413    /// Utility function for topological sort
414    #[allow(clippy::only_used_in_recursion)]
415    fn topo_sort_util(
416        &self,
417        node: &str,
418        graph: &HashMap<String, Vec<String>>,
419        visited: &mut HashSet<String>,
420        rec_stack: &mut HashSet<String>,
421        result: &mut Vec<String>,
422    ) -> Result<()> {
423        visited.insert(node.to_string());
424        rec_stack.insert(node.to_string());
425
426        if let Some(dependencies) = graph.get(node) {
427            for dep in dependencies {
428                if !visited.contains(dep) {
429                    self.topo_sort_util(dep, graph, visited, rec_stack, result)?;
430                } else if rec_stack.contains(dep) {
431                    return Err(Error::generic(format!(
432                        "Circular dependency detected involving '{}'",
433                        node
434                    )));
435                }
436            }
437        }
438
439        rec_stack.remove(node);
440        result.push(node.to_string());
441        Ok(())
442    }
443
444    /// Collect requests that can be executed in parallel (same dependency level)
445    fn collect_dependency_level(
446        &self,
447        request_id: String,
448        _graph: &HashMap<String, Vec<String>>,
449        level: &mut Vec<String>,
450        processed: &mut HashSet<String>,
451    ) {
452        level.push(request_id.clone());
453        processed.insert(request_id);
454    }
455
456    /// Expand template string with chain context
457    fn expand_template(&self, template: &str, context: &ChainTemplatingContext) -> String {
458        let templating_context = crate::templating::TemplatingContext {
459            chain_context: Some(context.clone()),
460            env_context: None,
461            virtual_clock: None,
462        };
463        crate::templating::expand_str_with_context(template, &templating_context)
464    }
465
466    /// Expand template variables in JSON value
467    fn expand_template_in_json(&self, value: &Value, context: &ChainTemplatingContext) -> Value {
468        match value {
469            Value::String(s) => Value::String(self.expand_template(s, context)),
470            Value::Array(arr) => {
471                Value::Array(arr.iter().map(|v| self.expand_template_in_json(v, context)).collect())
472            }
473            Value::Object(map) => {
474                let mut new_map = serde_json::Map::new();
475                for (k, v) in map {
476                    new_map.insert(
477                        self.expand_template(k, context),
478                        self.expand_template_in_json(v, context),
479                    );
480                }
481                Value::Object(new_map)
482            }
483            _ => value.clone(),
484        }
485    }
486
487    /// Extract value from response using JSON path-like syntax
488    fn extract_from_response(&self, response: &ChainResponse, path: &str) -> Option<Value> {
489        let parts: Vec<&str> = path.split('.').collect();
490
491        if parts.is_empty() || parts[0] != "body" {
492            return None;
493        }
494
495        let mut current = response.body.as_ref()?;
496
497        for part in &parts[1..] {
498            match current {
499                Value::Object(map) => {
500                    current = map.get(*part)?;
501                }
502                Value::Array(arr) => {
503                    if part.starts_with('[') && part.ends_with(']') {
504                        let index_str = &part[1..part.len() - 1];
505                        if let Ok(index) = index_str.parse::<usize>() {
506                            current = arr.get(index)?;
507                        } else {
508                            return None;
509                        }
510                    } else {
511                        return None;
512                    }
513                }
514                _ => return None,
515            }
516        }
517
518        Some(current.clone())
519    }
520}
521
522/// Result of chain execution
523#[derive(Debug, Clone)]
524pub struct ChainExecutionResult {
525    pub chain_id: String,
526    pub status: ChainExecutionStatus,
527    pub total_duration_ms: u64,
528    pub request_results: HashMap<String, ChainResponse>,
529    pub error_message: Option<String>,
530}
531
/// Overall outcome category of a chain run.
#[derive(Clone, Debug, PartialEq)]
pub enum ChainExecutionStatus {
    /// Every request in the chain ran without error.
    Successful,
    /// Presumably: some requests succeeded and some failed.
    /// NOTE(review): not produced by the engine in this module — confirm intended use.
    PartialSuccess,
    /// Presumably: the chain could not complete.
    /// NOTE(review): not produced by the engine in this module — confirm intended use.
    Failed,
}
539
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Build an engine backed by an empty registry and the default configuration.
    fn test_engine() -> ChainExecutionEngine {
        let registry = Arc::new(RequestChainRegistry::new(ChainConfig::default()));
        ChainExecutionEngine::new(registry, ChainConfig::default())
    }

    /// Build a 200 response carrying `body` as its parsed JSON body.
    fn response_with_body(body: Value) -> ChainResponse {
        ChainResponse {
            status: 200,
            headers: HashMap::new(),
            body: Some(body),
            duration_ms: 0,
            executed_at: String::new(),
            error: None,
        }
    }

    #[tokio::test]
    async fn test_engine_creation() {
        let engine = test_engine();

        // A freshly created engine has not recorded any executions yet.
        assert!(engine.get_chain_history("unknown-chain").await.is_empty());
    }

    #[tokio::test]
    async fn test_topological_sort() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec![]);
        graph.insert("B".to_string(), vec!["A".to_string()]);
        graph.insert("C".to_string(), vec!["A".to_string()]);
        graph.insert("D".to_string(), vec!["B".to_string(), "C".to_string()]);

        let topo_order = engine.topological_sort(&graph).unwrap();

        // `topological_sort` lists dependents before their dependencies:
        // D precedes B and C, and both B and C precede A.
        let position = |id: &str| topo_order.iter().position(|x| x == id).unwrap();
        let (a_pos, b_pos, c_pos, d_pos) = (position("A"), position("B"), position("C"), position("D"));

        assert!(d_pos < b_pos, "D should come before B");
        assert!(d_pos < c_pos, "D should come before C");
        assert!(b_pos < a_pos, "B should come before A");
        assert!(c_pos < a_pos, "C should come before A");
        assert_eq!(topo_order.len(), 4, "Should have all 4 nodes");
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let engine = test_engine();

        let mut graph = HashMap::new();
        graph.insert("A".to_string(), vec!["B".to_string()]);
        graph.insert("B".to_string(), vec!["A".to_string()]); // Circular dependency
        assert!(engine.topological_sort(&graph).is_err());

        let mut self_loop = HashMap::new();
        self_loop.insert("A".to_string(), vec!["A".to_string()]);
        assert!(engine.topological_sort(&self_loop).is_err());
    }

    #[tokio::test]
    async fn test_extract_from_response() {
        let engine = test_engine();
        let response = response_with_body(serde_json::json!({
            "user": { "id": 7 },
            "items": [{ "name": "a" }, { "name": "b" }]
        }));

        assert_eq!(
            engine.extract_from_response(&response, "body.user.id"),
            Some(serde_json::json!(7))
        );
        assert_eq!(
            engine.extract_from_response(&response, "body.items.[1].name"),
            Some(serde_json::json!("b"))
        );
        assert_eq!(engine.extract_from_response(&response, "body.items.[5]"), None);
        assert_eq!(engine.extract_from_response(&response, "body.missing"), None);
        // Paths must be rooted at `body`.
        assert_eq!(engine.extract_from_response(&response, "headers.x"), None);
    }
}