use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use tokio::task::JoinSet;
use http::Method;
use super::{Request, RequestExecution, ResponseSnapshot};
/// Outcome of one successfully executed request, in plan order.
#[derive(Debug, Clone)]
pub struct ExecutionRecord {
    /// Position of the request in the original `requests` slice.
    pub index: usize,
    /// Human-readable identifier; also the key used for dependency lookups.
    pub description: String,
    /// HTTP method of the executed request.
    pub method: Method,
    /// Target URL of the executed request.
    pub url: String,
    /// Captured execution result (exported env vars and response snapshot
    /// are read from it when building dependent requests' contexts).
    pub execution: RequestExecution,
    /// Wall-clock time when the request started (for reporting).
    pub started_at: SystemTime,
    /// Elapsed time, measured with a monotonic `Instant`.
    pub duration: Duration,
}
/// Knobs controlling how a request plan is executed.
#[derive(Debug, Clone)]
pub struct ExecutionOptions {
    /// Run independent requests concurrently instead of one at a time.
    pub parallel: bool,
    /// Cap on concurrently in-flight requests. Setting `Some(_)` also
    /// enables parallel mode even when `parallel` is false.
    pub max_concurrency: Option<usize>,
    /// Collect failures and keep executing remaining requests instead of
    /// aborting on the first failure.
    pub continue_on_error: bool,
}
impl Default for ExecutionOptions {
fn default() -> Self {
Self {
parallel: false,
max_concurrency: None,
continue_on_error: false,
}
}
}
impl ExecutionOptions {
    /// Parallel mode is active when explicitly requested or when a
    /// concurrency cap has been supplied.
    fn parallel_enabled(&self) -> bool {
        match self.max_concurrency {
            Some(_) => true,
            None => self.parallel,
        }
    }

    /// Effective upper bound on in-flight requests; unbounded when no cap is set.
    fn concurrency_limit(&self) -> usize {
        match self.max_concurrency {
            Some(cap) => cap,
            None => usize::MAX,
        }
    }
}
pub type ExecutionObserver = Arc<dyn Fn(ExecutionEvent) + Send + Sync>;
/// Progress notifications delivered to an [`ExecutionObserver`].
#[derive(Debug, Clone)]
pub enum ExecutionEvent {
    /// A request finished successfully.
    RequestCompleted { record: ExecutionRecord },
    /// A request failed, was skipped due to a failed dependency, or a
    /// background task died.
    RequestFailed { failure: RequestFailure },
    /// An assertion attached to a request passed.
    /// NOTE(review): not emitted anywhere in this module — presumably
    /// emitted from `Request::exec`; confirm against that implementation.
    AssertionPassed { request: String, assertion: String },
}
/// Aggregate error: every failure that occurred, plus the records of
/// requests that did complete before (or, with `continue_on_error`,
/// despite) the failures.
#[derive(Debug)]
pub struct ExecutionError {
    failures: Vec<RequestFailure>,
    completed: Vec<ExecutionRecord>,
}
impl ExecutionError {
pub fn new(failures: Vec<RequestFailure>, completed: Vec<ExecutionRecord>) -> Self {
Self {
failures,
completed,
}
}
pub fn from_failure(failure: RequestFailure, completed: Vec<ExecutionRecord>) -> Self {
Self::new(vec![failure], completed)
}
pub fn into_parts(self) -> (Vec<RequestFailure>, Vec<ExecutionRecord>) {
(self.failures, self.completed)
}
}
impl std::fmt::Display for ExecutionError {
    /// Prints the first failure on its own line, then each further failure
    /// as an indented ` - ` bullet. A failure-less error (shouldn't happen
    /// in practice) prints a generic message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut remaining = self.failures.iter();
        match remaining.next() {
            None => write!(f, "Execution failed for unknown reason"),
            Some(first) => {
                writeln!(f, "{}", first)?;
                for failure in remaining {
                    writeln!(f, " - {}", failure)?;
                }
                Ok(())
            }
        }
    }
}

impl std::error::Error for ExecutionError {}
/// Why a single request failed or was skipped.
#[derive(Debug, Clone)]
pub struct RequestFailure {
    // `usize::MAX` is a sentinel meaning "no associated plan index"
    // (used by `join_error`; decoded back to `None` by `index()`).
    index: usize,
    // Description of the failing request, or "<unknown>" for join errors.
    request: String,
    kind: RequestFailureKind,
}
impl RequestFailure {
fn execution(index: usize, request: String, message: String) -> Self {
Self {
index,
request,
kind: RequestFailureKind::Execution { message },
}
}
fn dependency(index: usize, request: String, dependency: String) -> Self {
Self {
index,
request,
kind: RequestFailureKind::Dependency { dependency },
}
}
fn missing_dependency(index: usize, request: String, dependency: String) -> Self {
Self {
index,
request,
kind: RequestFailureKind::MissingDependency { dependency },
}
}
fn join_error(message: String) -> Self {
Self {
index: usize::MAX,
request: "<unknown>".to_string(),
kind: RequestFailureKind::Join { message },
}
}
pub fn index(&self) -> Option<usize> {
if self.index == usize::MAX {
None
} else {
Some(self.index)
}
}
pub fn request(&self) -> &str {
&self.request
}
pub fn kind(&self) -> &RequestFailureKind {
&self.kind
}
}
impl std::fmt::Display for RequestFailure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.kind {
RequestFailureKind::Execution { message } => write!(
f,
"Request '{}' (index {}) failed during execution: {}",
self.request, self.index, message
),
RequestFailureKind::Dependency { dependency } => write!(
f,
"Request '{}' (index {}) skipped because dependency '{}' failed",
self.request, self.index, dependency
),
RequestFailureKind::MissingDependency { dependency } => write!(
f,
"Request '{}' (index {}) expected dependency '{}' to complete, but no execution record was found",
self.request, self.index, dependency
),
RequestFailureKind::Join { message } => write!(
f,
"Background task failed unexpectedly: {}",
message
),
}
}
}
/// Classification of request failures.
#[derive(Debug, Clone)]
pub enum RequestFailureKind {
    /// The request ran and returned an error (stringified).
    Execution { message: String },
    /// Skipped because the named dependency failed.
    Dependency { dependency: String },
    /// The named dependency produced no execution record — an internal
    /// invariant violation rather than a user-facing failure.
    MissingDependency { dependency: String },
    /// A spawned background task panicked or was cancelled.
    Join { message: String },
}
/// Executes `requests` in the planner-provided `order` with no observer.
///
/// Thin wrapper over [`execute_plan_with_observer`] passing `None`.
///
/// # Errors
/// Returns an [`ExecutionError`] bundling every failure together with the
/// records of requests that did complete.
pub async fn execute_plan(
    requests: &[Request],
    order: &[usize],
    options: ExecutionOptions,
) -> Result<Vec<ExecutionRecord>, ExecutionError> {
    execute_plan_with_observer(requests, order, options, None).await
}
/// Executes `requests` in `order`, optionally reporting progress through
/// `observer`.
///
/// Dispatches to the parallel executor when `options` enable parallelism
/// (explicitly, or implicitly via a concurrency cap); otherwise runs the
/// requests strictly sequentially.
pub async fn execute_plan_with_observer(
    requests: &[Request],
    order: &[usize],
    options: ExecutionOptions,
    observer: Option<ExecutionObserver>,
) -> Result<Vec<ExecutionRecord>, ExecutionError> {
    // Shadow with a borrow; the executors only need `Option<&ExecutionObserver>`.
    let observer = observer.as_ref();
    if options.parallel_enabled() {
        execute_parallel(requests, order, &options, observer).await
    } else {
        execute_sequential(requests, order, &options, observer).await
    }
}
/// Runs requests one at a time in `order`, threading each dependency's
/// exported env vars and response snapshot into dependent requests.
///
/// Failure policy: with `options.continue_on_error` set, failures are
/// collected and dependents of a failed request are skipped (reported as
/// `Dependency` failures); otherwise execution aborts on the first failure,
/// returning the records completed so far.
async fn execute_sequential(
    requests: &[Request],
    order: &[usize],
    options: &ExecutionOptions,
    observer: Option<&ExecutionObserver>,
) -> Result<Vec<ExecutionRecord>, ExecutionError> {
    // Completed executions keyed by request description (the dependency key).
    let mut results: HashMap<String, RequestExecution> = HashMap::new();
    let mut records: Vec<ExecutionRecord> = Vec::new();
    let mut failures: Vec<RequestFailure> = Vec::new();
    // Descriptions of requests that failed or were skipped; consulted to
    // cascade-skip their dependents.
    let mut failed_requests: HashSet<String> = HashSet::new();
    for &idx in order {
        let request = requests
            .get(idx)
            .expect("planner produced invalid request index");
        // Cascade-skip: some dependency already failed.
        if let Some(dep) = request
            .dependencies
            .iter()
            .find(|dep| failed_requests.contains(*dep))
        {
            let failure = RequestFailure::dependency(idx, request.description.clone(), dep.clone());
            emit_event(
                observer,
                ExecutionEvent::RequestFailed {
                    failure: failure.clone(),
                },
            );
            // Mark this request failed too so its own dependents are skipped.
            failed_requests.insert(request.description.clone());
            if options.continue_on_error {
                failures.push(failure);
                continue;
            } else {
                records.sort_by_key(|record| record.index);
                return Err(ExecutionError::from_failure(failure, records));
            }
        }
        // Merge exported env vars and collect snapshots from dependencies.
        let (inherited_context, dependency_snapshots) =
            match build_dependency_context(request, &results) {
                Ok(ctx) => ctx,
                Err(missing) => {
                    // A dependency has no execution record: broken invariant
                    // (the planner should order dependencies first).
                    let failure = RequestFailure::missing_dependency(
                        idx,
                        request.description.clone(),
                        missing,
                    );
                    emit_event(
                        observer,
                        ExecutionEvent::RequestFailed {
                            failure: failure.clone(),
                        },
                    );
                    failed_requests.insert(request.description.clone());
                    if options.continue_on_error {
                        failures.push(failure);
                        continue;
                    } else {
                        records.sort_by_key(|record| record.index);
                        return Err(ExecutionError::from_failure(failure, records));
                    }
                }
            };
        // SystemTime for reporting "when"; Instant for measuring "how long".
        let started_at = SystemTime::now();
        let timer = Instant::now();
        match request
            .exec(&inherited_context, &dependency_snapshots, observer)
            .await
        {
            Ok(execution) => {
                let duration = timer.elapsed();
                // Make the result available to downstream dependents.
                results.insert(request.description.clone(), execution.clone());
                let record = ExecutionRecord {
                    index: idx,
                    description: request.description.clone(),
                    method: request.method.clone(),
                    url: request.url.clone(),
                    execution,
                    started_at,
                    duration,
                };
                emit_event(
                    observer,
                    ExecutionEvent::RequestCompleted {
                        record: record.clone(),
                    },
                );
                records.push(record);
            }
            Err(err) => {
                let failure =
                    RequestFailure::execution(idx, request.description.clone(), err.to_string());
                emit_event(
                    observer,
                    ExecutionEvent::RequestFailed {
                        failure: failure.clone(),
                    },
                );
                failed_requests.insert(request.description.clone());
                if options.continue_on_error {
                    failures.push(failure);
                    continue;
                } else {
                    records.sort_by_key(|record| record.index);
                    return Err(ExecutionError::from_failure(failure, records));
                }
            }
        }
    }
    // Records are returned in plan-index order regardless of execution order.
    records.sort_by_key(|record| record.index);
    if failures.is_empty() {
        Ok(records)
    } else {
        Err(ExecutionError::new(failures, records))
    }
}
/// Runs requests concurrently, bounded by `options.concurrency_limit()`
/// (clamped to at least 1), scheduling each request only once all of its
/// dependencies have produced results.
///
/// Scheduler shape: each round scans `order`, spawns every ready request up
/// to the limit, then awaits one task completion before re-scanning, since
/// that completion may unblock dependents.
///
/// NOTE(review): two divergences from the sequential path to confirm:
/// (1) the `build_dependency_context` error branch here ignores
/// `continue_on_error` — it should be unreachable because readiness is
/// checked just above; (2) if no remaining request can ever be scheduled,
/// the loop breaks silently and may return `Ok` with fewer records than
/// `order.len()` — relies on the planner producing a schedulable order.
async fn execute_parallel(
    requests: &[Request],
    order: &[usize],
    options: &ExecutionOptions,
    observer: Option<&ExecutionObserver>,
) -> Result<Vec<ExecutionRecord>, ExecutionError> {
    // Completed executions keyed by request description (the dependency key).
    let mut results: HashMap<String, RequestExecution> = HashMap::new();
    let mut records: Vec<ExecutionRecord> = Vec::new();
    let mut failures: Vec<RequestFailure> = Vec::new();
    // Descriptions of failed/skipped requests, used to cascade-skip dependents.
    let mut failed_requests: HashSet<String> = HashSet::new();
    // Indices accounted for: succeeded, failed, or skipped.
    let mut completed: HashSet<usize> = HashSet::new();
    // Indices currently running in the JoinSet.
    let mut in_flight: HashSet<usize> = HashSet::new();
    let total = order.len();
    // A limit of 0 would make the scheduler spin without progress; clamp to 1.
    let limit = options.concurrency_limit().max(1);
    // Each task resolves to everything needed to build an ExecutionRecord;
    // errors are pre-stringified so the payload is owned and `'static`.
    let mut join_set: JoinSet<(
        usize,
        String,
        Method,
        String,
        SystemTime,
        Duration,
        Result<RequestExecution, String>,
    )> = JoinSet::new();
    while completed.len() < total {
        let mut scheduled_this_round = false;
        // Scheduling pass: spawn every request whose dependencies are ready.
        for &idx in order {
            if completed.contains(&idx) || in_flight.contains(&idx) {
                continue;
            }
            let request = requests
                .get(idx)
                .expect("planner produced invalid request index");
            // Cascade-skip: a dependency already failed.
            if let Some(dep) = request
                .dependencies
                .iter()
                .find(|dep| failed_requests.contains(*dep))
            {
                let failure =
                    RequestFailure::dependency(idx, request.description.clone(), dep.clone());
                emit_event(
                    observer,
                    ExecutionEvent::RequestFailed {
                        failure: failure.clone(),
                    },
                );
                failed_requests.insert(request.description.clone());
                // Counts as "completed" so the outer loop can terminate.
                completed.insert(idx);
                if options.continue_on_error {
                    failures.push(failure);
                    continue;
                } else {
                    failures.push(failure);
                    records.sort_by_key(|record| record.index);
                    // Cancel everything still running before bailing out.
                    join_set.abort_all();
                    return Err(ExecutionError::new(failures, records));
                }
            }
            // Not ready yet: some dependency has not produced a result.
            if !request
                .dependencies
                .iter()
                .all(|dep| results.contains_key(dep))
            {
                continue;
            }
            // Concurrency cap reached; wait for a completion before spawning more.
            if in_flight.len() >= limit {
                break;
            }
            let (inherited_context, dependency_snapshots) =
                match build_dependency_context(request, &results) {
                    Ok(ctx) => ctx,
                    Err(missing) => {
                        // Should be unreachable: readiness was verified above.
                        let failure = RequestFailure::missing_dependency(
                            idx,
                            request.description.clone(),
                            missing,
                        );
                        emit_event(
                            observer,
                            ExecutionEvent::RequestFailed {
                                failure: failure.clone(),
                            },
                        );
                        failures.push(failure);
                        records.sort_by_key(|record| record.index);
                        join_set.abort_all();
                        return Err(ExecutionError::new(failures, records));
                    }
                };
            // Clone everything the task needs so it owns its data (`'static`).
            let observer_for_exec = observer.cloned();
            let req_clone = request.clone();
            join_set.spawn(async move {
                let started_at = SystemTime::now();
                let timer = Instant::now();
                let desc = req_clone.description.clone();
                let method = req_clone.method.clone();
                let url = req_clone.url.clone();
                let result = req_clone
                    .exec(
                        &inherited_context,
                        &dependency_snapshots,
                        observer_for_exec.as_ref(),
                    )
                    .await
                    .map_err(|err| err.to_string());
                let duration = timer.elapsed();
                (idx, desc, method, url, started_at, duration, result)
            });
            in_flight.insert(idx);
            scheduled_this_round = true;
        }
        if in_flight.is_empty() {
            // Nothing running and nothing newly scheduled: no further
            // progress is possible, so stop scanning.
            if !scheduled_this_round {
                break;
            }
            continue;
        }
        // Await one completion; it may unblock dependents next round.
        if let Some(join_result) = join_set.join_next().await {
            match join_result {
                Ok((idx, desc, method, url, started_at, duration, outcome)) => {
                    in_flight.remove(&idx);
                    completed.insert(idx);
                    match outcome {
                        Ok(execution) => {
                            // Make the result available to downstream dependents.
                            results.insert(desc.clone(), execution.clone());
                            let record = ExecutionRecord {
                                index: idx,
                                description: desc,
                                method,
                                url,
                                execution,
                                started_at,
                                duration,
                            };
                            emit_event(
                                observer,
                                ExecutionEvent::RequestCompleted {
                                    record: record.clone(),
                                },
                            );
                            records.push(record);
                        }
                        Err(message) => {
                            // Dependents will be cascade-skipped next round.
                            failed_requests.insert(desc.clone());
                            let failure = RequestFailure::execution(idx, desc.clone(), message);
                            emit_event(
                                observer,
                                ExecutionEvent::RequestFailed {
                                    failure: failure.clone(),
                                },
                            );
                            if options.continue_on_error {
                                failures.push(failure);
                            } else {
                                failures.push(failure);
                                records.sort_by_key(|record| record.index);
                                join_set.abort_all();
                                return Err(ExecutionError::new(failures, records));
                            }
                        }
                    }
                }
                Err(err) => {
                    // The task panicked or was aborted; abort the whole run.
                    let failure = RequestFailure::join_error(err.to_string());
                    emit_event(
                        observer,
                        ExecutionEvent::RequestFailed {
                            failure: failure.clone(),
                        },
                    );
                    failures.push(failure);
                    records.sort_by_key(|record| record.index);
                    join_set.abort_all();
                    return Err(ExecutionError::new(failures, records));
                }
            }
        }
    }
    // Present records in plan-index order regardless of completion order.
    records.sort_by_key(|record| record.index);
    if failures.is_empty() {
        Ok(records)
    } else {
        Err(ExecutionError::new(failures, records))
    }
}
/// Forwards `event` to the observer callback, when one was supplied.
fn emit_event(observer: Option<&ExecutionObserver>, event: ExecutionEvent) {
    match observer {
        Some(callback) => callback(event),
        None => {}
    }
}
/// Gathers the context a request inherits from its dependencies: a merged
/// map of all exported environment variables (later dependencies in the
/// list overwrite earlier ones on key collision) and, per dependency, its
/// response snapshot.
///
/// # Errors
/// Returns the name of the first dependency that has no execution record
/// in `results`.
fn build_dependency_context(
    request: &Request,
    results: &HashMap<String, RequestExecution>,
) -> Result<(HashMap<String, String>, HashMap<String, ResponseSnapshot>), String> {
    let mut inherited_context = HashMap::new();
    let mut dependency_snapshots = HashMap::new();
    for dependency in &request.dependencies {
        let dep_exec = match results.get(dependency) {
            Some(execution) => execution,
            None => return Err(dependency.clone()),
        };
        inherited_context.extend(
            dep_exec
                .export_env
                .iter()
                .map(|(key, value)| (key.clone(), value.clone())),
        );
        dependency_snapshots.insert(dependency.clone(), dep_exec.snapshot.clone());
    }
    Ok((inherited_context, dependency_snapshots))
}